Flax
Flax is a neural network library for JAX
Flax is a flexible neural-network library for JAX that embraces functional programming while offering ergonomic module abstractions. Its design separates pure computation from state by threading parameter collections and RNGs explicitly, which makes runs reproducible and lets models compose directly with JAX transforms such as jit, pmap, and vmap. Modules define parameterized computations, but initialization and application remain side-effect free, which pairs naturally with JAX’s staging...