
Constants

parax.constants.AbstractConstant

Bases: Module, Generic[T]

An abstract interface and structural tag for a constant node.

This class is primarily used as a type-check for parax.is_constant and parax.is_free_array to facilitate PyTree partitioning.

Note: This interface only provides the structural tagging required by Parax. It does not automatically apply jax.lax.stop_gradient during computations. Concrete implementations (like parax.Freeze and parax.Fixed) handle the actual JAX-level gradient stopping and unwrapping logic.
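To show how a purely structural tag supports partitioning, here is a minimal plain-Python sketch. It assumes `is_constant` reduces to an `isinstance` check against the tag class. `Tagged`, `partition`, and the stand-in `AbstractConstant` (a plain `abc.ABC`, not a `Module`) are hypothetical names for illustration, not Parax's API. As the note above says, the tag does nothing to gradients.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, Tuple, TypeVar

T = TypeVar("T")


class AbstractConstant(ABC, Generic[T]):
    """Plain-Python stand-in for parax.constants.AbstractConstant."""

    @abstractmethod
    def free(self) -> T: ...


class Tagged(AbstractConstant[T]):
    """Hypothetical concrete tag; it does NOT apply any stop_gradient."""

    def __init__(self, value: T) -> None:
        self.value = value

    def free(self) -> T:
        return self.value


def is_constant(x: Any) -> bool:
    # Hypothetical: the tag is detected purely by type.
    return isinstance(x, AbstractConstant)


def partition(params: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a flat parameter dict into (trainable, constant) parts."""
    trainable = {k: v for k, v in params.items() if not is_constant(v)}
    constant = {k: v for k, v in params.items() if is_constant(v)}
    return trainable, constant


params = {"w": 1.5, "b": 0.1, "scale": Tagged(2.0)}
trainable, constant = partition(params)
print(sorted(trainable))  # ['b', 'w']
print(sorted(constant))   # ['scale']
```

A real JAX workflow would partition a full PyTree rather than a flat dict, but the principle is the same: the type tag alone decides which side of the split a node lands on.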

free() abstractmethod

Return the underlying value, stripping the constant tag.

Source code in parax/constants.py
@abstractmethod
def free(self) -> T:
    """Return the underlying value, stripping the constant tag."""
    pass
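Because `free()` is declared with `@abstractmethod`, the interface cannot be instantiated on its own; each concrete constant has to define how its payload is recovered. The sketch below shows that contract with a plain `abc.ABC` stand-in. `Box` is a hypothetical subclass, and whether Parax's `Module` base enforces abstract methods in exactly the same way is an assumption here.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

T = TypeVar("T")


class AbstractConstant(ABC, Generic[T]):
    """Plain-Python stand-in for parax.constants.AbstractConstant."""

    @abstractmethod
    def free(self) -> T:
        """Return the underlying value, stripping the constant tag."""


class Box(AbstractConstant[T]):
    """Hypothetical subclass that satisfies the contract by returning its payload."""

    def __init__(self, value: T) -> None:
        self.value = value

    def free(self) -> T:
        return self.value


try:
    AbstractConstant()  # free() is unimplemented, so instantiation fails
except TypeError as err:
    print("TypeError:", err)

box = Box([1, 2, 3])
print(box.free())  # [1, 2, 3]
```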

parax.as_free(value)

Returns a freed version of value by stripping any constant wrappers.

If value implements AbstractConstant, this calls value.free(). Otherwise, it acts as a safe no-op and returns value unchanged. This makes it safe to use directly within a jax.tree_util.tree_map over mixed trees.

Parameters:

value (Union[AbstractConstant[T], T], required): An arbitrary value, potentially wrapped in an AbstractConstant.

Returns:

T: The freed value, or the original value unchanged if it was not wrapped in a constant.

Source code in parax/constants.py
def as_free(value: Union[AbstractConstant[T], T]) -> T:
    """
    Returns a freed version of `value` by stripping any constant wrappers.

    If `value` implements `AbstractConstant`, this calls `value.free()`.
    Otherwise, it acts as a safe no-op and returns `value` unchanged. This
    makes it safe to use directly within a `jax.tree_util.tree_map` over
    mixed trees.

    Args:
        value: An arbitrary value, potentially wrapped in an `AbstractConstant`.

    Returns:
        The freed value, or the original value if it was not a constant.
    """
    if is_constant(value):
        return value.free()
    return value
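One caveat on the tree-map remark: a tree map only calls its function on leaves. If the constant wrappers are themselves PyTree nodes (as Equinox modules are; the `Module` base suggests this, but the page does not confirm it), the traversal descends into the wrapper, `as_free` never sees it, and the wrapper is rebuilt around its contents. Passing `is_leaf=parax.is_constant` to `jax.tree_util.tree_map` stops the traversal at the wrapper instead. The sketch below reproduces this in plain Python, so it runs without JAX. It uses a hypothetical `toy_tree_map` that mimics `tree_map`'s `is_leaf` behavior; `Tagged` and the stand-in classes are likewise illustrative.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, Optional, TypeVar

T = TypeVar("T")


class AbstractConstant(ABC, Generic[T]):
    # Plain-Python stand-in for parax.constants.AbstractConstant.
    @abstractmethod
    def free(self) -> T: ...


class Tagged(AbstractConstant[T]):
    # Hypothetical wrapper; toy_tree_map treats it as an internal node
    # with one child, the way a PyTree-registered wrapper would be treated.
    def __init__(self, value: T) -> None:
        self.value = value

    def free(self) -> T:
        return self.value


def is_constant(x: Any) -> bool:
    return isinstance(x, AbstractConstant)


def as_free(value: Any) -> Any:
    # Same logic as parax.as_free in the listing above.
    if is_constant(value):
        return value.free()
    return value


def toy_tree_map(f: Callable, tree: Any,
                 is_leaf: Optional[Callable[[Any], bool]] = None) -> Any:
    """Tiny stand-in for jax.tree_util.tree_map over dicts, lists and Tagged."""
    if is_leaf is not None and is_leaf(tree):
        return f(tree)
    if isinstance(tree, dict):
        return {k: toy_tree_map(f, v, is_leaf) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_tree_map(f, v, is_leaf) for v in tree]
    if isinstance(tree, Tagged):
        # Descend into the wrapper and rebuild it, like a PyTree node.
        return Tagged(toy_tree_map(f, tree.value, is_leaf))
    return f(tree)


tree = {"w": [1.0, 2.0], "scale": Tagged(3.0)}

naive = toy_tree_map(as_free, tree)
print(type(naive["scale"]).__name__)  # Tagged: as_free only saw the leaf 3.0

stripped = toy_tree_map(as_free, tree, is_leaf=is_constant)
print(stripped)  # {'w': [1.0, 2.0], 'scale': 3.0}
```

With JAX installed, the analogous call would be `jax.tree_util.tree_map(parax.as_free, tree, is_leaf=parax.is_constant)`, again assuming the wrappers are registered PyTrees.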