Senior Research Engineer, JAX at AssemblyAI

Location Information: United States

AssemblyAI is an applied artificial intelligence company. We use the latest deep learning technology to build practical products that bring futuristic ideas to life. Our team includes researchers, engineers, and designers who have worked at some of the largest technology companies all over the world. Our main office is located in downtown San Francisco.

At AssemblyAI, we believe that cutting-edge artificial intelligence technology should not be limited to those with the funding or resources to invest in it. Our goal is to help make creative, new ideas possible by making AI technology accessible to everyone through easy-to-use products, whether you are an independent developer, startup, or global company.

Maintain and evolve the JAX training framework for scalability and efficiency in large-scale distributed training runs.
Optimize production JAX inference systems for speech-to-text models using advanced techniques like continuous batching, model sharding, paged attention, and quantization.
Refactor and modernize model architectures and infrastructure, translating research prototypes into production-ready systems.
Investigate and resolve performance bottlenecks across the stack, from low-level kernels (XLA, Pallas) to high-level system design.
Design and deploy scalable, distributed workloads optimized for TPU and GPU architectures.
Bridge Research and Engineering teams to ensure seamless knowledge transfer and alignment on technical priorities.

Expert-level proficiency with JAX and its ecosystem (Flax, Optax, XLA compilation pipeline).
Strong experience optimizing inference systems for production, ideally with LLMs or speech models.
Hands-on experience with TPU programming and optimization; GPU/CUDA expertise is also valuable.
Passion for refactoring and improving existing systems to make code faster, cleaner, and more maintainable.
Familiarity with modern inference optimization techniques: continuous batching, KV-cache management, sharding strategies, quantization.
Domain knowledge in Speech-to-Text (ASR architectures, audio processing, streaming inference) is a plus.
Strong Python skills; C++ or Rust experience for kernel-level work is a plus.
Deep understanding of distributed training at scale and ML infrastructure best practices.
Excellent communication skills and a collaborative mindset to clearly explain complex tradeoffs and prioritize high-impact work.

Pay range: $190K - $248K