Penzai, developed by Google DeepMind, is a JAX-based library for representing, visualizing, and manipulating neural network models as functional pytree data structures. It aims to make machine learning research more interpretable and interactive, particularly for model surgery, ablation studies, architecture debugging, and interpretability research. Because a Penzai model is itself a pytree whose fields hold its sublayers and parameters, its full internal structure can be inspected and modified after training instead of being hidden behind framework internals. Its modular design includes tools for tree manipulation, named axes, and declarative neural network construction. The library integrates tightly with Treescope, an interactive pretty-printer for deeply nested JAX pytrees that can also render the contents of multidimensional arrays. Penzai's penzai.nn module provides a compositional, combinator-based API for building neural networks.
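To make the "models as pytrees" idea concrete, here is a minimal sketch in plain JAX. It does not use Penzai's own API: the Linear and MLP classes and the init_mlp and ablate_hidden_unit helpers are hypothetical, built on typing.NamedTuple, which JAX already treats as a pytree. The sketch performs a small ablation: it produces an edited copy of the model with one hidden unit forced to zero, leaves the original untouched, and passes the whole model through jax.jit as an ordinary argument.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Linear(NamedTuple):
    """Hypothetical dense layer; a NamedTuple is already a JAX pytree."""
    weight: jax.Array  # shape (in_features, out_features)
    bias: jax.Array    # shape (out_features,)

    def __call__(self, x):
        return x @ self.weight + self.bias


class MLP(NamedTuple):
    """Hypothetical MLP: a tuple of Linear layers with ReLU between them."""
    layers: tuple

    def __call__(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = jax.nn.relu(x)
        return x

    def hidden(self, x):
        """Activations after the first layer (used to check the ablation)."""
        return jax.nn.relu(self.layers[0](x))


def init_mlp(key, sizes):
    """Hypothetical initializer: one Linear per consecutive pair in `sizes`."""
    keys = jax.random.split(key, len(sizes) - 1)
    return MLP(tuple(
        Linear(weight=jax.random.normal(k, (m, n)) / jnp.sqrt(m), bias=jnp.zeros(n))
        for k, m, n in zip(keys, sizes[:-1], sizes[1:])
    ))


def ablate_hidden_unit(model, unit):
    """Return a copy of `model` whose first-layer output `unit` is always zero."""
    first = model.layers[0]
    ablated_first = first._replace(
        weight=first.weight.at[:, unit].set(0.0),
        bias=first.bias.at[unit].set(0.0),
    )
    # Immutable update: the original model's arrays are not modified.
    return model._replace(layers=(ablated_first,) + model.layers[1:])


model = init_mlp(jax.random.PRNGKey(0), [4, 8, 2])
# Generic tree utilities can inspect the model because it is a pytree.
shapes = jax.tree_util.tree_map(lambda a: a.shape, model)
ablated = ablate_hidden_unit(model, unit=3)
# jit (and grad) accept the whole model as an ordinary argument.
x = jnp.ones((5, 4))
out = jax.jit(lambda m, inputs: m(inputs))(ablated, x)
```

Penzai builds its layers on its own pytree base classes, and pz.select can locate the parts to edit by type or path instead of the hand-written indexing in ablate_hidden_unit.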
Features
- Represents models as readable JAX pytrees, so every sublayer and parameter is visible and addressable
- Integrates with Treescope for interactive visualization of complex model hierarchies
- Provides pytree manipulation utilities (pz.select) for selecting and rewriting targeted parts of a model
- Supports named axes (pz.nx), so array dimensions are referred to by name rather than position and vectorization over them is explicit
- Provides a combinator-based neural network library (pz.nn) that supports stateful layers and parameter sharing
- Includes modular transformer implementations for the Gemma, LLaMA, Mistral, and GPT-NeoX architectures