Diffrax is a library of numerical differential equation solvers built for the JAX ecosystem, with a focus on composability, differentiability, and high-performance scientific computing. It solves ordinary differential equations (ODEs), stochastic differential equations (SDEs), controlled differential equations (CDEs), and related systems, and is designed to fit into machine learning and differentiable programming workflows. Because it is built on JAX, solves can be just-in-time compiled, differentiated automatically, vectorized, and run on accelerators such as GPUs and TPUs. This makes it appealing for researchers who need solvers that can be embedded inside trainable models or simulation-heavy learning systems.
Features
- Differential equation solvers for ODEs, SDEs, CDEs, and related systems
- Native integration with JAX for automatic differentiation and JIT compilation
- Support for GPU and TPU execution through the JAX backend
- Modular design for solver methods, controllers, and adjoint strategies
- Event handling and flexible time-stepping for complex simulations
- Suited to scientific computing, neural differential equations, and simulation-based learning