Name | Modified | Size |
---|---|---|
JAX v0.7.0 source code.tar.gz | 2025-07-22 | 17.3 MB |
JAX v0.7.0 source code.zip | 2025-07-22 | 18.5 MB |
README.md | 2025-07-22 | 3.5 kB |
Total (3 items) | | 35.8 MB |


New features:

- Added `jax.P`, an alias for `jax.sharding.PartitionSpec`.
- Added `jax.tree.reduce_associative`.


Breaking changes:

- JAX is migrating from GSPMD to Shardy by default. See the migration guide for more information.
- JAX autodiff is switching to direct linearization by default (instead of implementing linearization via JVP and partial eval). See the migration guide for more information.
- `jax.stages.OutInfo` has been replaced with `jax.ShapeDtypeStruct`.
- `jax.jit` now requires `fun` to be passed by position and any additional arguments to be passed by keyword. Doing otherwise will result in an error starting in v0.7.x; in v0.6.x it raised a `DeprecationWarning`.
- The minimum Python version is now 3.11. 3.11 will remain the minimum supported version until July 2026.
- Layout API renames:
  - `Layout`, `.layout`, `.input_layouts` and `.output_layouts` have been renamed to `Format`, `.format`, `.input_formats` and `.output_formats`.
  - `DeviceLocalLayout` and `.device_local_layout` have been renamed to `Layout` and `.layout`.
- The `jax.experimental.shard` module has been deleted and all of its APIs have been moved to the `jax.sharding` endpoint. Use `jax.sharding.reshard`, `jax.sharding.auto_axes` and `jax.sharding.explicit_axes` instead of their experimental counterparts.
- `lax.infeed` and `lax.outfeed` were removed after being deprecated in JAX 0.6. The `transfer_to_infeed` and `transfer_from_outfeed` methods were also removed from `Device` objects.
- The `jax.extend.core.primitives.pjit_p` primitive has been renamed to `jit_p`, and its `name` attribute has changed from `"pjit"` to `"jit"`. This affects the string representations of jaxprs. The same primitive is no longer exported from the `jax.experimental.pjit` module.
- The (undocumented) function `jax.extend.backend.add_clear_backends_callback` has been removed. Use `jax.extend.backend.register_backend_cache` instead.

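The `jax.jit` calling convention can be sketched as follows. The functions `repeat_add` and `repeat_mul` are hypothetical. The point is that `fun` is the only positional argument to `jax.jit`, and options such as `static_argnames` are passed by keyword. This also applies to the decorator form, where `functools.partial` binds the keyword options.

```python
import functools

import jax
import jax.numpy as jnp


def repeat_add(x, n):
    # `n` drives a Python loop, so it must be static (a compile-time constant).
    out = jnp.zeros_like(x)
    for _ in range(n):
        out = out + x
    return out


# Wrapping form: `fun` goes first, by position; options go by keyword.
repeat_add_jit = jax.jit(repeat_add, static_argnames="n")


# Decorator form: functools.partial binds the keyword options, and the
# decorated function is then passed to jax.jit positionally.
@functools.partial(jax.jit, static_argnames="n")
def repeat_mul(x, n):
    out = jnp.ones_like(x)
    for _ in range(n):
        out = out * x
    return out


x = jnp.arange(4.0)
print(repeat_add_jit(x, n=3))  # [0. 3. 6. 9.]
print(repeat_mul(x, n=2))      # [0. 1. 4. 9.]
```

By contrast, the notes above say that calls passing `fun` by keyword (for example `jax.jit(fun=repeat_add)`) or passing options positionally after `fun` now raise an error.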

Deprecations:

- `jax.dlpack.SUPPORTED_DTYPES` is deprecated; use the new `jax.dlpack.is_supported_dtype` function instead.
- `jax.scipy.special.sph_harm` has been deprecated following a similar deprecation in SciPy; use `jax.scipy.special.sph_harm_y` instead.
- From `jax.interpreters.xla`, the previously deprecated symbols `abstractify` and `pytype_aval_mappings` have been removed.
- `jax.interpreters.xla.canonicalize_dtype` is deprecated. For canonicalizing dtypes, prefer `jax.dtypes.canonicalize_dtype`. For checking whether an object is a valid JAX input, prefer `jax.core.valid_jaxtype`.
- From `jax.core`, the previously deprecated symbols `AxisName`, `ConcretizationTypeError`, `axis_frame`, `call_p`, `closed_call_p`, `get_type`, `trace_state_clean`, `typematch`, and `typecheck` have been removed.
- From `jax.lib.xla_client`, the previously deprecated symbols `DeviceAssignment`, `get_topology_for_devices`, and `mlir_api_version` have been removed.
- `jax.extend.ffi` was removed after being deprecated in v0.5.0. Use `jax.ffi` instead.
- `jax.lib.xla_bridge.get_compile_options` is deprecated and replaced by `jax.extend.backend.get_compile_options`.