Flax
Flax is a neural network library for JAX.
...Its design separates pure computation from state: parameter collections and random-number-generator (RNG) keys are passed explicitly into and out of functions rather than stored inside objects. This makes runs reproducible and lets models be used directly with JAX transformations such as jit, pmap, and vmap. Modules define parameterized computations, but initialization (init) and application (apply) are free of side effects: init returns the parameters, and apply takes them as an argument. This pairs naturally with JAX’s staging and compilation model, which expects pure functions. Flax also emphasizes composability: optimizers, training loops, and checkpointing are provided as examples or standalone utilities rather than as a monolithic framework, so researchers can customize each piece. The library is widely used in vision, language, and reinforcement learning, often as a thin layer atop JAX’s NumPy-like primitives. ...