Constants
parax.constants.AbstractConstant
Bases: Module, Generic[T]
An abstract interface and structural tag for a constant node.
This class is primarily the type that parax.is_constant and
parax.is_free_array check against, which is what lets Parax partition PyTrees.
Note: This interface only provides the structural tagging required
by Parax. It does not automatically apply jax.lax.stop_gradient during
computations. Concrete implementations (like parax.Freeze and parax.Fixed)
handle the actual JAX-level gradient stopping and unwrapping logic.
free() (abstract method)
Return the underlying value, stripping the constant tag.
Source code in parax/constants.py
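The tagging pattern described above can be sketched in plain Python. This is a hypothetical, dependency-free illustration, not Parax's implementation: abc.ABC stands in for the Module base, and Const and is_constant are illustrative names standing in for concrete classes such as parax.Fixed and for parax.is_constant. Like the interface itself, Const only tags and unwraps a value; it does not call jax.lax.stop_gradient.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class AbstractConstant(ABC, Generic[T]):
    """Structural tag for a constant node (abc.ABC stands in for Module here)."""

    @abstractmethod
    def free(self) -> T:
        """Return the underlying value, stripping the constant tag."""


@dataclass(frozen=True)
class Const(AbstractConstant[T]):
    """Hypothetical concrete constant: tags a value and unwraps it on free().

    Unlike parax.Freeze or parax.Fixed, this sketch applies no stop_gradient.
    """

    value: T

    def free(self) -> T:
        return self.value


def is_constant(x: Any) -> bool:
    """Hypothetical stand-in for parax.is_constant: a plain isinstance check."""
    return isinstance(x, AbstractConstant)


# Partition a flat list of leaves into tagged constants and free values.
leaves = [Const(1.0), 2.0, Const((3, 4)), "bias"]
constants = [x for x in leaves if is_constant(x)]
free_values = [x for x in leaves if not is_constant(x)]
unwrapped = [c.free() for c in constants]  # [1.0, (3, 4)]
```

Because the check is purely structural (an isinstance test against the abstract base), any subclass that implements free() is recognized as a constant without further registration.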
parax.as_free(value)
Returns a freed version of value by stripping any constant wrappers.
If value is an AbstractConstant, this calls value.free().
Otherwise, it acts as a safe no-op and returns value unchanged. This
makes it safe to use directly with jax.tree_util.tree_map over trees that
mix constant and non-constant leaves.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[AbstractConstant[T], T] | An arbitrary value, potentially wrapped in an AbstractConstant. | required |
Returns:
| Type | Description |
|---|---|
| T | The freed value, or the original value unchanged if it was not wrapped in a constant. |
Source code in parax/constants.py
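A minimal sketch of the dispatch described for as_free, again in plain Python. The classes are hypothetical stand-ins, not Parax's own, and the sketch assumes the wrapper is unwrapped through its free() method. A dict comprehension plays the role of the tree map over a mixed structure.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Union

T = TypeVar("T")


class AbstractConstant(ABC, Generic[T]):
    """Hypothetical stand-in for parax.constants.AbstractConstant."""

    @abstractmethod
    def free(self) -> T: ...


class Const(AbstractConstant[T]):
    """Hypothetical concrete constant that simply stores a value."""

    def __init__(self, value: T) -> None:
        self.value = value

    def free(self) -> T:
        return self.value


def as_free(value: Union[AbstractConstant[T], T]) -> T:
    """Strip the constant wrapper if present; otherwise return value as-is."""
    if isinstance(value, AbstractConstant):
        return value.free()
    return value  # safe no-op for plain values


# Mixed structure: some entries are wrapped, some are not.
params = {"w": Const(0.5), "b": 0.1, "name": "layer0"}
freed = {k: as_free(v) for k, v in params.items()}
# freed == {"w": 0.5, "b": 0.1, "name": "layer0"}
```

If the wrappers are themselves PyTree nodes (as equinox Modules are), jax.tree_util.tree_map descends into them by default, so passing an is_leaf predicate that recognizes constants is one way to have each wrapper visited as a unit.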