Tunix
A JAX-native LLM Post-Training Library
Tunix is a JAX-native library for post-training large language models. It brings supervised fine-tuning, reinforcement learning–based alignment, and knowledge distillation into one toolkit. It builds on JAX’s strengths (functional programming, jit compilation, and multi-device execution) so that experiments can scale from a single GPU to TPU pods with minimal code changes. The library is organized around modular pipelines for data loading, rollout, optimization, and evaluation, so practitioners can swap individual components without rewriting the whole stack. Examples and reference configs demonstrate end-to-end runs for common model families, helping teams reproduce baselines before customizing. ...
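To make the "swap components without rewriting the whole stack" idea concrete, here is a minimal sketch in plain Python. It is not Tunix's API: `Pipeline`, `toy_loader`, `toy_rollout`, `sgd_step`, `bigger_step`, and `toy_eval` are hypothetical names invented for illustration, and the "model" is a single scalar fitted to the target `y = 2x`. It only shows how the four stages named above (data loading, rollout, optimization, evaluation) might be held as independent, replaceable functions.

```python
from dataclasses import dataclass, replace
from typing import Callable, Iterable, List

# NOTE: every name below is a hypothetical illustration, not a Tunix API.
Batch = List[float]


@dataclass(frozen=True)
class Pipeline:
    """Four independent stages; each one can be replaced on its own."""
    load_data: Callable[[], Iterable[Batch]]
    rollout: Callable[[float, Batch], List[float]]
    optimize: Callable[[float, Batch, List[float]], float]
    evaluate: Callable[[float], float]

    def run(self, param: float) -> float:
        # Functional style: the parameter is passed in and returned,
        # never mutated in place.
        for batch in self.load_data():
            outputs = self.rollout(param, batch)
            param = self.optimize(param, batch, outputs)
        return self.evaluate(param)


def toy_loader() -> List[Batch]:
    # 100 tiny batches of scalar inputs.
    return [[1.0, 2.0], [3.0, 4.0]] * 50


def toy_rollout(param: float, batch: Batch) -> List[float]:
    # Toy "model": prediction = param * x.
    return [param * x for x in batch]


def sgd_step(param: float, batch: Batch, outputs: List[float],
             lr: float = 0.01) -> float:
    # Gradient of mean((param * x - 2x)^2) with respect to param.
    grad = sum(2.0 * (p - 2.0 * x) * x
               for p, x in zip(outputs, batch)) / len(batch)
    return param - lr * grad


def bigger_step(param: float, batch: Batch, outputs: List[float]) -> float:
    # Same update rule with a larger learning rate.
    return sgd_step(param, batch, outputs, lr=0.05)


def toy_eval(param: float) -> float:
    # Distance from the target parameter value 2.0.
    return abs(param - 2.0)


baseline = Pipeline(toy_loader, toy_rollout, sgd_step, toy_eval)
# Swap only the optimizer; loader, rollout, and evaluation are reused as-is.
variant = replace(baseline, optimize=bigger_step)

baseline_error = baseline.run(0.0)
variant_error = variant.run(0.0)
```

Keeping each stage a plain function with explicit inputs and outputs is also the style that jit compilation in JAX favors, although nothing in this sketch depends on JAX.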