
Constant

parax.constant.AbstractConstant

Bases: Module, Generic[T]

An abstract interface and structural tag for a constant node.

This class is primarily used as the type check behind parax.is_constant and parax.is_free_array, which make it possible to partition a PyTree into constant and free parts.
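Below is a minimal plain-Python sketch of how a type-based tag can drive partitioning. It uses hypothetical names (`ConstantTag`, `Tagged`, `is_constant`, `partition`) and does not use Parax, JAX, or Equinox. The real parax.is_constant and Parax's PyTree partitioning may differ in detail.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class ConstantTag(ABC, Generic[T]):
    """Hypothetical stand-in for AbstractConstant (plain Python, not an Equinox Module)."""

    @abstractmethod
    def as_free(self) -> T:
        """Return the underlying value, stripping the constant tag."""


class Tagged(ConstantTag[T]):
    """Hypothetical concrete constant that simply stores its value."""

    def __init__(self, value: T) -> None:
        self.value = value

    def as_free(self) -> T:
        return self.value


def is_constant(node: Any) -> bool:
    # Structural check: a node counts as constant purely because of its type.
    return isinstance(node, ConstantTag)


def partition(tree: Any) -> tuple[list[Any], list[Any]]:
    """Split the leaves of a nested list/tuple/dict into (constant, free) lists."""
    constant: list[Any] = []
    free: list[Any] = []

    def visit(node: Any) -> None:
        if is_constant(node):
            # A constant node is treated as a single leaf; do not recurse into it.
            constant.append(node)
        elif isinstance(node, dict):
            # Visit keys in sorted order so the output order is deterministic.
            for key in sorted(node):
                visit(node[key])
        elif isinstance(node, (list, tuple)):
            for child in node:
                visit(child)
        else:
            free.append(node)

    visit(tree)
    return constant, free
```

For example, `partition({"w": [1.0, 2.0], "scale": Tagged(0.5)})` puts the tagged node in the first list and `[1.0, 2.0]` in the second. The constant is matched by its type before any recursion happens, so its contents are never treated as free leaves.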

Note: This interface provides only the structural tagging that Parax requires. It does not automatically apply jax.lax.stop_gradient during computations. Concrete implementations, such as parax.Frozen and parax.Fixed, handle the actual JAX-level gradient stopping and unwrapping logic.

as_free() abstractmethod

Return the underlying value, stripping the constant tag.

Source code in parax/constant.py
@abstractmethod
def as_free(self) -> T:
    """Return the underlying value, stripping the constant tag."""
    pass
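The next sketch shows one way a caller could use an `as_free()`-style method to strip every constant tag from a nested structure. It is plain Python with hypothetical names (`ConstantTag`, `Pinned`, `strip_constants`). It is not Parax's own unwrapping utility, and unlike parax.Frozen or parax.Fixed it performs no gradient stopping.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

T = TypeVar("T")


class ConstantTag(ABC, Generic[T]):
    """Hypothetical plain-Python stand-in for AbstractConstant."""

    @abstractmethod
    def as_free(self) -> T:
        """Return the underlying value, stripping the constant tag."""


class Pinned(ConstantTag[T]):
    """Hypothetical concrete subclass; it only stores and returns its value."""

    def __init__(self, value: T) -> None:
        self._value = value

    def as_free(self) -> T:
        return self._value


def strip_constants(tree: Any) -> Any:
    """Recursively replace every constant node in a list/tuple/dict with its as_free() value."""
    if isinstance(tree, ConstantTag):
        return tree.as_free()
    if isinstance(tree, dict):
        return {k: strip_constants(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type (plain list or tuple) with unwrapped children.
        return type(tree)(strip_constants(v) for v in tree)
    return tree
```

For example, `strip_constants({"a": Pinned(1), "b": [2, Pinned(3)]})` returns `{"a": 1, "b": [2, 3]}`. Because `as_free` is declared abstract, a subclass that forgets to implement it cannot be instantiated.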