  • New features:
  • Added jax.P, an alias for jax.sharding.PartitionSpec.
  • Added jax.tree.reduce_associative.

  • Breaking changes:

  • JAX is migrating from GSPMD to Shardy, which is now the default partitioner. See the migration guide for more information.
  • JAX autodiff now uses direct linearization by default, instead of implementing linearization via JVP and partial evaluation. See the migration guide for more information.
  • jax.stages.OutInfo has been replaced with jax.ShapeDtypeStruct.
  • jax.jit now requires fun to be passed by position and all other arguments to be passed by keyword. Passing them any other way raises an error as of v0.7.0; in v0.6.x it raised a DeprecationWarning.
  • The minimum Python version is now 3.11. Python 3.11 will remain the minimum supported version until July 2026.
  • Layout API renames:
    • Layout, .layout, .input_layouts, and .output_layouts have been renamed to Format, .format, .input_formats, and .output_formats.
    • DeviceLocalLayout and .device_local_layout have been renamed to Layout and .layout.
  • The jax.experimental.shard module has been deleted, and its APIs have moved to jax.sharding: use jax.sharding.reshard, jax.sharding.auto_axes, and jax.sharding.explicit_axes instead of their experimental counterparts.
  • lax.infeed and lax.outfeed were removed after being deprecated in JAX 0.6. The transfer_to_infeed and transfer_from_outfeed methods were also removed from Device objects.
  • The jax.extend.core.primitives.pjit_p primitive has been renamed to jit_p, and its name attribute has changed from "pjit" to "jit". This affects the string representations of jaxprs. The same primitive is no longer exported from the jax.experimental.pjit module.
  • The (undocumented) function jax.extend.backend.add_clear_backends_callback has been removed. Users should use jax.extend.backend.register_backend_cache instead.
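A minimal sketch of the jax.jit calling convention required in v0.7.0, assuming JAX 0.7.0 or later: the function goes by position and options such as static_argnames go by keyword, both in direct calls and in the functools.partial decorator form. The functions `power` and `cube_then_power` are hypothetical examples. The last line uses jax.eval_shape only as a familiar source of jax.ShapeDtypeStruct values, the type that replaces jax.stages.OutInfo for describing outputs; it is not one of the APIs that changed.

```python
import functools

import jax
import jax.numpy as jnp


def power(x, n):
    # Hypothetical example: `n` drives a Python loop, so it must be static under jit.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


# `fun` is positional; every other option (here static_argnames) is a keyword.
power_jit = jax.jit(power, static_argnames="n")


# The decorator form follows the same rule through functools.partial.
@functools.partial(jax.jit, static_argnames=("n",))
def cube_then_power(x, n):
    return power(x ** 3, n)


x = jnp.arange(4.0)
squares = power_jit(x, n=2)
cubes = cube_then_power(x, n=1)

# Output shape/dtype metadata expressed as a jax.ShapeDtypeStruct.
out_struct = jax.eval_shape(lambda a: power_jit(a, n=2), x)
```

Passing options such as static_argnames positionally after `fun` is the pattern that now errors.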

  • Deprecations:

  • jax.dlpack.SUPPORTED_DTYPES is deprecated; please use the new jax.dlpack.is_supported_dtype function.
  • jax.scipy.special.sph_harm has been deprecated following a similar deprecation in SciPy; use jax.scipy.special.sph_harm_y instead.
  • From jax.interpreters.xla, the previously deprecated symbols abstractify and pytype_aval_mappings have been removed.
  • jax.interpreters.xla.canonicalize_dtype is deprecated. For canonicalizing dtypes, prefer jax.dtypes.canonicalize_dtype. For checking whether an object is a valid jax input, prefer jax.core.valid_jaxtype.
  • From jax.core, the previously deprecated symbols AxisName, ConcretizationTypeError, axis_frame, call_p, closed_call_p, get_type, trace_state_clean, typematch, and typecheck have been removed.
  • From jax.lib.xla_client, the previously deprecated symbols DeviceAssignment, get_topology_for_devices, and mlir_api_version have been removed.
  • jax.extend.ffi was removed after being deprecated in v0.5.0. Use jax.ffi instead.
  • jax.lib.xla_bridge.get_compile_options is deprecated in favor of jax.extend.backend.get_compile_options.
Source: README.md, updated 2025-07-22