Constant
parax.constant.AbstractConstant
Bases: Module, Generic[T]
An abstract interface and structural tag for a constant node.
This class primarily serves as the type that parax.is_constant and
parax.is_free_array check against, which facilitates PyTree partitioning.
Note: This interface only provides the structural tagging required
by Parax. It does not automatically apply jax.lax.stop_gradient during
computations. Concrete implementations (like parax.Frozen and parax.Fixed)
handle the actual JAX-level gradient stopping and unwrapping logic.
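The tagging idea described above can be sketched in plain Python. This is an illustration only, not Parax's implementation: ConstantTag, Pinned, is_tagged_constant and partition are hypothetical names, the base class uses abc.ABC rather than Module, and no gradient stopping is performed.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class ConstantTag(ABC, Generic[T]):
    """Hypothetical stand-in for an abstract constant tag (not Parax's class)."""

    @abstractmethod
    def as_free(self) -> T:
        """Return the wrapped value without the tag."""


class Pinned(ConstantTag[T]):
    """Hypothetical concrete wrapper; a real one would also stop gradients."""

    def __init__(self, value: T) -> None:
        self.value = value

    def as_free(self) -> T:
        return self.value


def is_tagged_constant(node: Any) -> bool:
    """Structural check only: it inspects the type, never the value."""
    return isinstance(node, ConstantTag)


def partition(params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Split a flat mapping into (free, constant) halves using the tag."""
    free = {k: v for k, v in params.items() if not is_tagged_constant(v)}
    const = {k: v for k, v in params.items() if is_tagged_constant(v)}
    return free, const


# "bias" is tagged, so it lands in the constant half; "weight" stays free.
params = {"weight": 0.5, "bias": Pinned(1.0)}
free, const = partition(params)
```

The predicate never looks at the wrapped value, which mirrors the note above: the tag decides where a node goes, while any change to how gradients flow is left to the concrete wrapper.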
as_free() (abstract method)
Return the underlying value, stripping the constant tag.
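A minimal sketch of what the abstract method implies for subclasses, again in plain Python with hypothetical names (ConstantTag, Pinned, Incomplete, can_instantiate) and abc.ABC standing in for the real base class: a subclass that does not implement as_free() cannot be instantiated, and one that does returns the untagged value.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

T = TypeVar("T")


class ConstantTag(ABC, Generic[T]):
    """Hypothetical abstract tag; only the method contract is modelled here."""

    @abstractmethod
    def as_free(self) -> T:
        """Return the underlying value, stripping the constant tag."""


class Pinned(ConstantTag[T]):
    """Hypothetical subclass that fulfils the contract."""

    def __init__(self, value: T) -> None:
        self.value = value

    def as_free(self) -> T:
        return self.value


class Incomplete(ConstantTag[T]):
    """Hypothetical subclass that forgets to implement as_free()."""


def can_instantiate(cls: type, *args: object) -> bool:
    """Report whether cls(*args) succeeds; abstract classes raise TypeError."""
    try:
        cls(*args)
    except TypeError:
        return False
    return True


wrapped = Pinned([1, 2, 3])
unwrapped = wrapped.as_free()  # the plain list, no longer tagged
```

Here can_instantiate(Incomplete) is False while can_instantiate(Pinned, 0) is True, and unwrapped is the same object that was wrapped rather than a copy. Whether Parax's own wrappers copy or transform the value on unwrapping is not stated in this section.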
Source code in parax/constant.py