Tunix is a JAX-native library for post-training large language models, bringing supervised fine-tuning, reinforcement learning–based alignment, and knowledge distillation into one coherent toolkit. It embraces JAX's strengths (functional programming, JIT compilation, and multi-device execution), so experiments scale from a single GPU to TPU pods with minimal code changes. The library is organized around modular pipelines for data loading, rollout, optimization, and evaluation, letting practitioners swap components without rewriting the whole stack.

Examples and reference configs demonstrate end-to-end runs for common model families, helping teams reproduce baselines before customizing. Tunix also leans into research ergonomics: logging, checkpointing, and metrics are built in, and the code is written to be hackable rather than monolithic. Overall, it aims to shorten the path from an off-the-shelf base model to a well-aligned, task-ready model using scalable JAX primitives.
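To make the JAX-first pattern described above concrete, here is a minimal sketch of a jit-compiled training step with an optax optimizer. It is a generic illustration under stated assumptions, not Tunix's actual API: the toy linear "model", parameter shapes, and function names are placeholders standing in for a real language model and its training loop.

```python
# Generic JAX post-training step (illustrative only, not Tunix's API).
import jax
import jax.numpy as jnp
import optax

optimizer = optax.adamw(learning_rate=1e-4)

def loss_fn(params, batch):
    # Toy cross-entropy objective over a linear head (placeholder for an LLM).
    logits = batch["inputs"] @ params["w"] + params["b"]
    labels = jax.nn.one_hot(batch["labels"], logits.shape[-1])
    return optax.softmax_cross_entropy(logits, labels).mean()

@jax.jit  # JIT compilation is what lets the same step scale across devices
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Example usage with placeholder shapes:
params = {"w": jnp.zeros((16, 8)), "b": jnp.zeros(8)}
opt_state = optimizer.init(params)
batch = {"inputs": jnp.ones((4, 16)), "labels": jnp.zeros(4, dtype=jnp.int32)}
params, opt_state, loss = train_step(params, opt_state, batch)
```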
## Features
- Supervised fine-tuning, RL-style alignment, and knowledge distillation pipelines (see the distillation sketch after this list)
- JAX-first design with jit/pmap for GPUs and TPUs
- Modular components for data, rollout, optimization, and eval
- Example configs and scripts for quick, reproducible baselines
- Built-in logging, checkpointing, and metric tracking
- Flax/NNX-friendly APIs that are easy to extend for research
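As referenced in the first feature above, the snippet below sketches the kind of loss a knowledge-distillation pipeline computes: KL divergence between temperature-softened teacher and student logits. This is a generic, Hinton-style formulation, not Tunix's own distillation module; the function name and arguments are illustrative assumptions.

```python
# Illustrative logit-distillation loss in JAX (names are placeholders, not Tunix identifiers).
import jax
import jax.numpy as jnp

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions."""
    t = temperature
    teacher_probs = jax.nn.softmax(teacher_logits / t, axis=-1)
    teacher_log_probs = jax.nn.log_softmax(teacher_logits / t, axis=-1)
    student_log_probs = jax.nn.log_softmax(student_logits / t, axis=-1)
    kl = jnp.sum(teacher_probs * (teacher_log_probs - student_log_probs), axis=-1)
    # The t**2 factor keeps gradient magnitudes comparable across temperatures.
    return (t ** 2) * kl.mean()
```

In practice this term is usually mixed with the ordinary supervised loss via a weighting coefficient, which is the standard design choice for logit distillation.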