Name / Modified / Size
transformer_engine_torch-2.11.0+cu12torch2.8.0+cu129cxx11abiTRUE-cp312-cp312-linux_aarch64.whl 2026-02-02 611.8 kB
transformer_engine_torch-2.11.0+cu12torch2.8.0+cu129cxx11abiTRUE-cp312-cp312-linux_x86_64.whl 2026-02-02 659.2 kB
transformer_engine_torch-2.11.0+cu13torch25.11cxx11abiTRUE-cp312-cp312-linux_aarch64.whl 2026-02-02 700.1 kB
transformer_engine_torch-2.11.0+cu13torch25.11cxx11abiTRUE-cp312-cp312-linux_x86_64.whl 2026-02-02 761.6 kB
README.md 2025-12-17 3.1 kB
v2.11 source code.tar.gz 2025-12-17 3.8 MB
v2.11 source code.zip 2025-12-17 4.3 MB
Totals: 7 items, 10.9 MB

Transformer Engine v2.11 Release Notes

Key Features and Enhancements

  • [PyTorch] Enabled the reference Current Scaling recipe for FP8 training. (#2368)
  • [PyTorch] Improved Random Hadamard Transform (RHT) device tensor caching to reduce memory allocations and improve performance for NVFP4 quantization. (#2395)
  • [PyTorch] Implemented selective activation checkpointing for the LayerNormMLP module. (#2311)
  • [C, PyTorch, JAX] Improved performance of MXFP8 quantization. (#2062)
  • [C, PyTorch] Improved performance of NVFP4 quantization. (#2351)
  • [PyTorch] Improved FSDP2 all-gather performance and added support for FusedAdam optimizer with FSDP2. (#2370)
  • [PyTorch] Extended debug tools to support GroupedLinear layers. (#1953)
  • [JAX] Added Triton kernel bindings for JAX, enabling custom Triton kernels in JAX workflows. (#2437)
  • [C] Introduced experimental NVTEGroupedTensor class and helper functions. (#2388)
  • [C, PyTorch, JAX] Added FP8 support for primary weights in MXFP8 format with partial casting and amax calculations. (#2055)
  • [JAX] Added support for context parallelism (CP) with THD format and sliding window attention (SWA) using all-gather (AG), including striped load balancing with stripe sizes greater than 1. (#2379)
  • [JAX] Implemented JAX primitives for token permutation operations on single GPU for mixture-of-experts routing. (#2473)
  • [PyTorch] Added THD format support for max_logit clipping and MuonClip gradient clipping operations. (#2480)
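
To give a sense of what the Current Scaling recipe in the first bullet refers to: with current scaling, the FP8 quantization scale is computed from the tensor's amax at the moment of the cast, rather than from a running history of past amaxes as in delayed scaling. Below is a minimal NumPy sketch of that idea; the function name and the clipping-only emulation of the FP8 cast are illustrative assumptions, not Transformer Engine's actual implementation.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def current_scale_quantize(x: np.ndarray):
    """Per-tensor current scaling: derive the scale from this tensor's
    amax right now, then cast. (Hypothetical sketch, not TE's kernel.)"""
    amax = float(np.abs(x).max())
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    # Emulate the FP8 cast with clipping only; a real FP8 cast also
    # rounds mantissas, which this sketch deliberately omits.
    q = np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, 1.0 / scale  # quantized values and the dequantization scale

x = np.array([0.1, -2.0, 3.5])
q, dequant_scale = current_scale_quantize(x)
recovered = q * dequant_scale
```

Because the scale reflects the live amax, outliers in the current tensor cannot overflow the FP8 range, at the cost of an extra amax reduction per cast.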

Fixed Issues

  • [PyTorch] Fixed a numerical issue when a noncontiguous tensor was passed to the cross_entropy backward pass. (#2402)
  • [PyTorch] Fixed CUDA graph execution order for backward weight gradient computation when using chunked layers. (#2376)
  • [C] Fixed runtime library loading logic to properly handle missing dependencies and load order. (#2297)
  • [JAX] Removed the scan loop as the default for ring attention to improve performance. (#2503)
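
The noncontiguous-tensor fix above is an instance of a common pitfall: a low-level kernel that indexes a raw buffer assuming row-major layout silently reads the wrong elements when handed a strided view. A minimal NumPy sketch of the defensive pattern (the kernel here is a hypothetical stand-in, not TE's cross_entropy kernel):

```python
import numpy as np

def kernel_expects_contiguous(buf: np.ndarray) -> np.ndarray:
    """Stand-in for a low-level kernel that walks the raw buffer
    assuming C-contiguous (row-major) memory layout."""
    assert buf.flags["C_CONTIGUOUS"], "kernel requires a contiguous buffer"
    return buf.sum(axis=-1)

x = np.arange(12, dtype=np.float32).reshape(3, 4)
xt = x.T  # transposed view: same memory, noncontiguous strides
# Defensive pattern: materialize a contiguous copy before the kernel sees it.
safe = np.ascontiguousarray(xt)
out = kernel_expects_contiguous(safe)
```

Calling the kernel on `xt` directly would trip the assertion; making the copy unconditional is cheap insurance when upstream code may hand you views.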

Breaking Changes in This Release

No breaking changes in this release.

Deprecated Features

No features deprecated in this release.

Source: README.md, updated 2025-12-17