RLax (pronounced “relax”) is a JAX-based library developed by Google DeepMind that provides reusable mathematical building blocks for constructing reinforcement learning (RL) agents. Rather than implementing full algorithms, RLax focuses on the core functional operations that underpin RL methods, such as computing value functions, returns, policy gradients, and loss terms, allowing researchers to assemble their own agents flexibly. It supports both on-policy and off-policy learning, as well as value-based, policy-based, and model-based approaches.

RLax is fully JIT-compilable with JAX, enabling high-performance execution across CPU, GPU, and TPU backends. The library implements tools for Bellman equations, return distributions, general value functions, and policy optimization in both continuous and discrete action spaces. It integrates with DeepMind’s Haiku (for neural network definition) and Optax (for optimization), making it a key component in modular RL pipelines.
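As a minimal sketch of how such building blocks are typically composed (the shapes and values below are illustrative, not taken from the library's own examples), a batched TD(0) value loss can be assembled by wrapping `rlax.td_learning` in `jax.vmap` and `jax.jit`:

```python
# A minimal sketch of composing RLax building blocks: a batched TD(0) loss.
# The shapes and values below are illustrative, not prescriptive.
import jax
import jax.numpy as jnp
import rlax


def td_loss(v_tm1, r_t, discount_t, v_t):
    """Mean squared TD error over a batch of transitions."""
    # rlax.td_learning works on a single transition; vmap lifts it to a batch.
    td_errors = jax.vmap(rlax.td_learning)(v_tm1, r_t, discount_t, v_t)
    return jnp.mean(jnp.square(td_errors))


# The loss is a pure function, so JAX can JIT-compile it for CPU/GPU/TPU.
td_loss = jax.jit(td_loss)

v_tm1 = jnp.zeros(4)            # value estimates at time t-1
r_t = jnp.ones(4)               # rewards received at time t
discount_t = jnp.full(4, 0.99)  # per-step discounts (0 at episode ends)
v_t = jnp.zeros(4)              # bootstrap value estimates at time t
print(td_loss(v_tm1, r_t, discount_t, v_t))
```

Most RLax operations are written for a single transition or timestep, so lifting them over batches with `jax.vmap` and compiling the result with `jax.jit` is the usual composition pattern.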
Features
- Modular reinforcement learning primitives (values, returns, and policies)
- Built on JAX for GPU/TPU acceleration and automatic differentiation
- Supports on-policy and off-policy learning paradigms
- Implements distributional value functions and general value functions
- Integrates with Haiku and Optax for neural network and optimization pipelines (see the sketch after this list)
- Comprehensive testing and examples for reproducibility and educational use
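To make the integration point concrete, here is a hedged sketch of a DQN-style update in which Haiku defines the Q-network, RLax supplies the loss and the exploration policy, and Optax applies the gradients. The network size, hyperparameters, and dummy batch are arbitrary choices for illustration, and the `rlax.epsilon_greedy(...).sample(...)` call assumes RLax's discrete-distribution interface; treat the exact calls as assumptions to check against the versions in use.

```python
# A sketch of a DQN-style update: RLax provides the loss and the exploration
# policy, Haiku defines the network, and Optax applies the parameter updates.
# Network sizes, hyperparameters, and the dummy batch are illustrative only.
import haiku as hk
import jax
import jax.numpy as jnp
import optax
import rlax

NUM_ACTIONS = 4
OBS_DIM = 8

# Haiku turns the network definition into a pure (init, apply) pair.
network = hk.without_apply_rng(
    hk.transform(lambda obs: hk.nets.MLP([64, NUM_ACTIONS])(obs)))
optimizer = optax.adam(1e-3)


def loss_fn(params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    """Mean Q-learning loss over a batch of transitions."""
    q_tm1 = network.apply(params, obs_tm1)
    q_t = network.apply(params, obs_t)
    # rlax.q_learning computes the TD error for one transition; vmap batches it.
    td_errors = jax.vmap(rlax.q_learning)(q_tm1, a_tm1, r_t, discount_t, q_t)
    return jnp.mean(rlax.l2_loss(td_errors))


@jax.jit
def update(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, *batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state


def select_action(params, key, obs, epsilon):
    # rlax.epsilon_greedy returns a discrete distribution over actions.
    q = network.apply(params, obs)
    return rlax.epsilon_greedy(epsilon).sample(key, q)


rng = jax.random.PRNGKey(42)
params = network.init(rng, jnp.zeros((1, OBS_DIM)))
opt_state = optimizer.init(params)

batch_size = 32
dummy_batch = (
    jnp.zeros((batch_size, OBS_DIM)),           # obs_tm1
    jnp.zeros((batch_size,), dtype=jnp.int32),  # a_tm1
    jnp.ones((batch_size,)),                    # r_t
    jnp.full((batch_size,), 0.99),              # discount_t
    jnp.zeros((batch_size, OBS_DIM)),           # obs_t
)
params, opt_state = update(params, opt_state, dummy_batch)
action = select_action(params, rng, jnp.zeros(OBS_DIM), epsilon=0.1)
```

The same structure carries over to other RLax losses (for example `rlax.double_q_learning` or `rlax.policy_gradient_loss`): only `loss_fn` changes, while the Haiku and Optax plumbing stays the same.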