| Name | Modified | Size |
|---|---|---|
| JAX v0.6.0 source code.tar.gz | 2025-04-17 | 16.9 MB |
| JAX v0.6.0 source code.zip | 2025-04-17 | 18.0 MB |
| README.md | 2025-04-17 | 4.6 kB |
| Totals: 3 items | | 35.0 MB |
Breaking changes

- `jax.numpy.array` no longer accepts `None`. This behavior has been deprecated since November 2023 and is now removed.
- Removed the `config.jax_data_dependent_tracing_fallback` config option, which was added temporarily in v0.4.36 to allow users to opt out of the new "stackless" tracing machinery.
- Removed the `config.jax_eager_pmap` config option.
- Calling the `lower` and `trace` AOT APIs on the result of `jax.jit` is now disallowed if other wrappers have been applied after `jax.jit`. Previously this worked, but silently ignored those wrappers. The workaround is to apply `jax.jit` last among the wrappers, and similarly for `jax.pmap`. See #27873.
- The `cuda12_pip` extra for `jax` has been removed; use `pip install jax[cuda12]` instead.
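A minimal sketch of the recommended wrapper order. It assumes a hypothetical `double_output` decorator built with `functools.wraps`; that decorator is not part of JAX and only stands in for "some other wrapper". Because `jax.jit` is applied last, `f` is the jitted function itself, so AOT calls such as `f.lower(...)` include the effect of `double_output` instead of skipping it.

```python
import functools

import jax
import jax.numpy as jnp


def double_output(fun):
    """Hypothetical wrapper, used only for illustration: doubles the result."""
    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        return 2 * fun(*args, **kwargs)
    return wrapper


# jax.jit is the outermost (last-applied) decorator, so `f` is the jitted
# function itself and AOT lowering traces through double_output too.
@jax.jit
@double_output
def f(x):
    return jnp.sin(x)


x = jnp.float32(0.5)
lowered = f.lower(x)          # AOT lowering for this input's shape and dtype
compiled = lowered.compile()  # compiled executable, callable like f
print(compiled(x))            # approximately 2 * sin(0.5)

# The reverse order is the case this release restricts:
#     g = double_output(jax.jit(jnp.sin))
#     g.lower(x)  # per the release note, no longer allowed
```

Putting `jax.jit` outermost costs nothing at runtime and keeps the AOT stages consistent with what a normal call to `f` computes.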
Changes

- The minimum cuDNN version is v9.8.
- JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain supported.
- JAX package extras now use dashes instead of underscores, in line with PEP 685. For instance, if you previously installed JAX with `pip install jax[cuda12_local]`, run `pip install jax[cuda12-local]` instead.
- `jax.jit` now requires `fun` to be passed by position, and all additional arguments to be passed by keyword. Doing otherwise results in a `DeprecationWarning` in v0.6.X and will be an error starting in v0.7.X.
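The positional/keyword rule for `jax.jit` can be sketched as follows. `scale` and `power` are hypothetical example functions; the sketch relies only on the existing `static_argnames` option of `jax.jit`, and the `functools.partial` decorator form is a common idiom rather than something this release note prescribes.

```python
import functools

import jax
import jax.numpy as jnp


def scale(x, factor):
    return x * factor


# `fun` is passed by position; every other jit option is passed by keyword.
scale_jit = jax.jit(scale, static_argnames=("factor",))


# Decorator form: bind the keyword options with functools.partial, so that
# jax.jit still receives the decorated function as its first positional argument.
@functools.partial(jax.jit, static_argnames=("n",))
def power(x, n):
    return x ** n


print(scale_jit(jnp.arange(3.0), 2.0))  # [0. 2. 4.]
print(power(jnp.float32(3.0), 2))       # 9.0

# Call styles that, per the release note, warn in v0.6.X and fail in v0.7.X:
#     jax.jit(fun=scale)          # `fun` passed by keyword
#     jax.jit(scale, <option>)    # a further option passed by position
```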
Deprecations

- `jax.tree_util.build_tree` is deprecated. Use `jax.tree.unflatten` instead.
- Implemented host callback handlers for CPU and GPU devices using XLA's FFI, and removed the existing CPU/GPU handlers that used XLA's custom call.
- All APIs in `jax.lib.xla_extension` are now deprecated.
- `jax.interpreters.mlir.hlo` and `jax.interpreters.mlir.func_dialect`, which were accidental exports, have been removed. If needed, they are available from `jax.extend.mlir`.
- `jax.interpreters.mlir.custom_call` is deprecated. Use the APIs provided by `jax.ffi` instead.
- The deprecated use of `jax.ffi.ffi_call` with inline arguments is no longer supported; `jax.ffi.ffi_call` now unconditionally returns a callable.
- The following exports in `jax.lib.xla_client` are deprecated: `get_topology_for_devices`, `heap_profile`, `mlir_api_version`, `Client`, `CompileOptions`, `DeviceAssignment`, `Frame`, `HloSharding`, `OpSharding`, and `Traceback`.
- The following internal APIs in `jax.util` are deprecated: `HashableFunction`, `as_hashable_function`, `cache`, `safe_map`, `safe_zip`, `split_dict`, `split_list`, `split_list_checked`, `split_merge`, `subvals`, `toposort`, `unzip2`, `wrap_name`, and `wraps`.
- `jax.dlpack.to_dlpack` has been deprecated. You can usually pass a JAX `Array` directly to the `from_dlpack` function of another framework. If you need the functionality of `to_dlpack`, use the `__dlpack__` attribute of an array.
- `jax.lax.infeed`, `jax.lax.infeed_p`, `jax.lax.outfeed`, and `jax.lax.outfeed_p` are deprecated and will be removed in JAX v0.7.0.
- Several previously deprecated APIs have been removed, including:
  - From `jax.lib.xla_client`: `ArrayImpl`, `FftType`, `PaddingType`, `PrimitiveType`, `XlaBuilder`, `dtype_to_etype`, `ops`, `register_custom_call_target`, `shape_from_pyval`, `Shape`, and `XlaComputation`.
  - From `jax.lib.xla_extension`: `ArrayImpl` and `XlaRuntimeError`.
  - From `jax`: `jax.treedef_is_leaf`, `jax.tree_flatten`, `jax.tree_map`, `jax.tree_leaves`, `jax.tree_structure`, `jax.tree_transpose`, and `jax.tree_unflatten`. Replacements can be found in `jax.tree` or `jax.tree_util`.
  - From `jax.core`: `AxisSize`, `ClosedJaxpr`, `EvalTrace`, `InDBIdx`, `InputType`, `Jaxpr`, `JaxprEqn`, `Literal`, `MapPrimitive`, `OpaqueTraceState`, `OutDBIdx`, `Primitive`, `Token`, `TRACER_LEAK_DEBUGGER_WARNING`, `Var`, `concrete_aval`, `dedup_referents`, `escaped_tracer_error`, `extend_axis_env_nd`, `full_lower`, `get_referent`, `jaxpr_as_fun`, `join_effects`, `lattice_join`, `leaked_tracer_error`, `maybe_find_leaked_tracers`, `raise_to_shaped`, `raise_to_shaped_mappings`, `reset_trace_state`, `str_eqn_compact`, `substitute_vars_in_output_ty`, `typecompat`, and `used_axis_names_jaxpr`. Most have no public replacement, though a few are available in `jax.extend.core`.
  - The `vectorized` argument to `jax.pure_callback` and `jax.ffi.ffi_call`. Use the `vmap_method` parameter instead.
  - From