Haiku is a simple neural network library built on top of JAX, designed to provide composable abstractions for machine learning research. JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support. Haiku lets users work with familiar object-oriented programming models while retaining full access to JAX's pure function transformations.
Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform. hk.Modules are Python objects that hold references to their own parameters, to other modules, and to methods that apply functions to user inputs. hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, and other JAX transformations.
Features
- Haiku has been tested at scale by researchers at DeepMind
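- Haiku does not reinvent the wheel: it builds on the programming model and APIs of Sonnet, DeepMind's neural network library for TensorFlow
- Haiku is a library, not a framework: it aims to make specific things simpler, such as managing model parameters and state, and to compose with other libraries and the rest of JAX
- Transitioning to Haiku is easy, particularly for users of TensorFlow and Sonnet
- Haiku makes other aspects of JAX simpler, for example working with random numbers via hk.next_rng_key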
- Documentation is available