Neural Tangents is a high-level neural network API for specifying complex, hierarchical models at both finite and infinite width, built in Python on top of JAX and XLA. Researchers define architectures from familiar building blocks (convolutions, pooling, residual connections, and nonlinearities) and obtain both the finite network and the Gaussian Process (GP) kernel of its infinite-width limit. From a single specification you can compute NNGP and NTK kernels, perform exact GP inference, and analytically study the training dynamics of infinitely wide networks.

The library closely mirrors JAX's stax API but extends it to return a kernel_fn alongside init_fn and apply_fn, so stax-style model definitions carry over directly to kernel computation. Kernel evaluation is optimized for speed and memory, and computations can be distributed automatically across multiple accelerators with near-linear scaling.
Features
- Unified API yielding init/apply functions and the GP kernel function
- Supports NNGP and NTK computation, corresponding to the Bayesian and gradient-descent infinite-width limits
- Optimized kernel evaluation on CPU, GPU, and TPU, with automatic distribution across accelerators
- Drop-in stax-style layers plus extras like LayerNorm and circular padding
- Tools for linearization, empirical kernels, and training-dynamics analysis
- Colab notebooks, examples, and tests for quick onboarding and research