Constraints

parax.constraints.AbstractConstraint

Bases: Module

The base class for all physical constraints in Parax.

A constraint bridges two views of the same parameter space: hard physical bounds (used by bounded optimizers or for user inspection) and a bijective mapping from the unconstrained real line onto the bounded space (used by unconstrained ML optimizers).

Constraints may be used directly on arrays or mapped over PyTrees.

Attributes:

bounds (AbstractVar[tuple[PyTree, PyTree]]): A tuple containing the lower and upper bounds of the constrained space.

bijector (AbstractVar[AbstractBijector]): The underlying mapping from the unconstrained real line to the bounded space.
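The bounds/bijector pairing can be sketched in plain JAX. This is a minimal illustration, not Parax's implementation: `SketchConstraint` is a hypothetical stand-in for an `AbstractConstraint` subclass, and its `forward`/`inverse` functions play the role of the bijector. The point it shows is that an optimizer stepping freely in unconstrained space always maps back inside `bounds`.

```python
from dataclasses import dataclass
from typing import Callable

import jax
import jax.numpy as jnp
from jax.scipy.special import logit


@dataclass(frozen=True)
class SketchConstraint:
    """Hypothetical stand-in for an AbstractConstraint subclass."""
    bounds: tuple      # (lower, upper) of the constrained space
    forward: Callable  # unconstrained real line -> bounded space
    inverse: Callable  # bounded space -> unconstrained real line


# Example: the open interval (0, 1) via a sigmoid/logit pair.
unit = SketchConstraint(
    bounds=(jnp.array(0.0), jnp.array(1.0)),
    forward=jax.nn.sigmoid,
    inverse=logit,
)

# An unconstrained optimizer works on z; the model only ever sees forward(z).
z = unit.inverse(jnp.array(0.25))
z = z - 10.0  # an arbitrarily large step in unconstrained space
value = unit.forward(z)
lower, upper = unit.bounds
print(bool((value > lower) & (value < upper)))  # True
```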

parax.constraints.RealLine(shape=())

Bases: AbstractConstraint

Represents a value that can span the entire real number line.

Effectively a structural no-op constraint using an Identity bijector, useful for maintaining consistent types in mixed parameter sets.

Attributes:

shape (Any): The expected shape of the unconstrained parameter.

Parameters:

shape (Any, default ()): The expected shape of the unconstrained parameter.
Source code in parax/constraints.py
def __init__(self, shape: Any = ()):
    """
    Args:
        shape: The expected shape of the unconstrained parameter.
    """
    self.shape = shape
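Why an identity constraint helps in mixed parameter sets can be sketched as follows: if every parameter carries some constraint, one uniform mapping can transform all of them without special-casing the unconstrained ones. The per-entry `(forward, inverse)` pairs below are hypothetical stand-ins for bijectors, and the softplus choice for `scale` is an assumption for illustration only.

```python
import jax
import jax.numpy as jnp

# Hypothetical (forward, inverse) pairs standing in for constraint bijectors.
identity = (lambda x: x, lambda y: y)                           # RealLine-style no-op
softplus = (jax.nn.softplus, lambda y: jnp.log(jnp.expm1(y)))  # strictly positive values

constraints = {"mean": identity, "scale": softplus}
unconstrained = {"mean": jnp.array(-3.0), "scale": jnp.array(0.5)}

# Because every entry has a constraint, one loop covers all of them.
constrained = {k: constraints[k][0](v) for k, v in unconstrained.items()}
recovered = {k: constraints[k][1](v) for k, v in constrained.items()}
```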

parax.constraints.GreaterThan(lower)

Bases: AbstractConstraint

Represents a value strictly greater than a lower bound.

Attributes:

lower (ndarray): The exclusive lower bound (array or scalar).

Parameters:

lower (Union[float, Array], required): The exclusive lower bound.
Source code in parax/constraints.py
def __init__(self, lower: Union[float, Array]):
    """
    Args:
        lower: The exclusive lower bound.
    """
    self.lower = jnp.asarray(lower, dtype=float)
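One common way to realize a strict lower bound is a shifted softplus, y = lower + softplus(x). The documentation above does not say which bijector `GreaterThan` actually uses, so treat the functions below as a hypothetical sketch. The inverse uses the numerically stable form of the inverse softplus, z + log(-expm1(-z)).

```python
import jax
import jax.numpy as jnp


def greater_than_forward(x, lower):
    # Maps any real x to a value strictly above `lower` (hypothetical sketch).
    return lower + jax.nn.softplus(x)


def greater_than_inverse(y, lower):
    # Stable inverse softplus applied to the distance above the bound.
    z = y - lower
    return z + jnp.log(-jnp.expm1(-z))


lower = jnp.asarray(2.0, dtype=float)  # same cast as GreaterThan.__init__
x = jnp.array([-5.0, 0.0, 5.0])
y = greater_than_forward(x, lower)
```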

parax.constraints.LessThan(upper)

Bases: AbstractConstraint

Represents a value strictly less than an upper bound.

Attributes:

upper (ndarray): The exclusive upper bound (array or scalar).

Parameters:

upper (Union[float, Array], required): The exclusive upper bound.
Source code in parax/constraints.py
def __init__(self, upper: Union[float, Array]):
    """
    Args:
        upper: The exclusive upper bound.
    """
    self.upper = jnp.asarray(upper, dtype=float)
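A strict upper bound is the mirror image of the lower-bound case. As a sketch, assuming an exponential mapping (Parax may well use a different bijector), y = upper - exp(x) stays below `upper` for every real x, and x = log(upper - y) inverts it. The bound may be an array, in which case it applies elementwise.

```python
import jax.numpy as jnp


def less_than_forward(x, upper):
    # exp(x) > 0, so the result is strictly below `upper` (hypothetical sketch).
    return upper - jnp.exp(x)


def less_than_inverse(y, upper):
    # Valid only for y < upper, i.e. inside the constrained space.
    return jnp.log(upper - y)


upper = jnp.asarray([1.0, 10.0], dtype=float)  # elementwise bounds
x = jnp.array([-2.0, 3.0])
y = less_than_forward(x, upper)
```

With an exponential map, float32 overflows to infinity once x exceeds roughly 88, whereas softplus grows only linearly; that is one practical reason softplus-style maps are often preferred.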

parax.constraints.Interval(lower, upper)

Bases: AbstractConstraint

Represents a value strictly bounded between a lower and upper value.

Attributes:

lower (ndarray): The exclusive lower bound.

upper (ndarray): The exclusive upper bound.

Parameters:

lower (Union[float, Array], required): The exclusive lower bound.

upper (Union[float, Array], required): The exclusive upper bound.
Source code in parax/constraints.py
def __init__(self, lower: Union[float, Array], upper: Union[float, Array]):
    """
    Args:
        lower: The exclusive lower bound.
        upper: The exclusive upper bound.
    """
    self.lower = jnp.asarray(lower, dtype=float)
    self.upper = jnp.asarray(upper, dtype=float)
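A standard way to map the real line onto an open interval is a scaled sigmoid, y = lower + (upper - lower) * sigmoid(x), with the logit as its inverse. This is a hypothetical sketch; the documentation does not state which bijector `Interval` uses.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logit


def interval_forward(x, lower, upper):
    # sigmoid(x) lies in (0, 1); rescale it into (lower, upper).
    return lower + (upper - lower) * jax.nn.sigmoid(x)


def interval_inverse(y, lower, upper):
    # Normalize back to (0, 1), then apply the logit.
    return logit((y - lower) / (upper - lower))


lower = jnp.asarray(-1.0, dtype=float)
upper = jnp.asarray(3.0, dtype=float)
x = jnp.array([-4.0, 0.0, 4.0])
y = interval_forward(x, lower, upper)
```

In float32 the sigmoid saturates for large |x|, so extreme unconstrained values can round onto a bound exactly even though the mathematical map never reaches it.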

parax.constraints.Positive(shape=(), dtype=None)

Bases: GreaterThan

Convenience constraint for values that must be strictly positive (> 0).

Parameters:

shape (Any, default ()): The shape of the parameter array.

dtype (Any, default None): The JAX data type of the parameter array.
Source code in parax/constraints.py
def __init__(self, shape: Any = (), dtype: Any = None):
    """
    Args:
        shape: The shape of the parameter array.
        dtype: The JAX data type of the parameter array.
    """
    super().__init__(lower=jnp.zeros(shape, dtype=dtype))
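`Positive` forwards a zero-filled array to `GreaterThan`, so its bound has the requested shape. The lines below replay that construction with plain `jnp` calls, assuming the source shown here is complete. Because `GreaterThan.__init__` re-casts the bound with `dtype=float`, the stored bound ends up in JAX's default floating dtype even when a different `dtype` is requested; the snippet checks this directly.

```python
import jax.numpy as jnp

shape, dtype = (2, 3), jnp.float16

# What Positive(shape, dtype) passes to GreaterThan ...
requested = jnp.zeros(shape, dtype=dtype)
# ... and what GreaterThan.__init__ then stores.
stored = jnp.asarray(requested, dtype=float)

print(stored.shape)  # (2, 3)
print(requested.dtype, stored.dtype)
```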

parax.constraints.Negative(shape=(), dtype=None)

Bases: LessThan

Convenience constraint for values that must be strictly negative (< 0).

Parameters:

shape (Any, default ()): The shape of the parameter array.

dtype (Any, default None): The JAX data type of the parameter array.
Source code in parax/constraints.py
def __init__(self, shape: Any = (), dtype: Any = None):
    """
    Args:
        shape: The shape of the parameter array.
        dtype: The JAX data type of the parameter array.
    """
    super().__init__(upper=jnp.zeros(shape, dtype=dtype))
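`Negative` mirrors `Positive`: it hands `LessThan` a zero-filled upper bound. Assuming a reflected softplus as the underlying map (a hypothetical choice for illustration), every real input lands strictly below zero:

```python
import jax
import jax.numpy as jnp

upper = jnp.zeros((3,))  # what Negative(shape=(3,)) passes as `upper`
x = jnp.array([-10.0, 0.0, 10.0])

# Hypothetical reflected-softplus map onto (-inf, upper).
y = upper - jax.nn.softplus(x)
```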

parax.constraints.TreeConstraint(constraints)

Bases: AbstractConstraint

Represents a PyTree of constraints mapping over a PyTree of inputs.

Useful for applying heterogeneous constraints to complex nested structures (like equinox.Module instances) simultaneously.

Attributes:

tree (PyTree[AbstractConstraint]): The PyTree containing AbstractConstraint leaves.

Parameters:

constraints (PyTree[AbstractConstraint], required): A PyTree containing AbstractConstraint leaves. Non-constraint leaves are ignored.

Raises:

ValueError: If the provided PyTree contains no constraint leaves.

Source code in parax/constraints.py
def __init__(
    self, 
    constraints: PyTree[AbstractConstraint],
):
    """
    Args:
        constraints: A PyTree containing `AbstractConstraint` leaves.
            Non-constraint leaves are ignored.

    Raises:
        ValueError: If the provided PyTree contains no constraint leaves.
    """
    # Local import prevents circular dependency at initialization time
    from parax.filters import is_constraint

    leaves = jax.tree.leaves(constraints, is_leaf=is_constraint)
    if not leaves:
        raise ValueError("The pytree of constraints cannot be empty.")

    self.tree = constraints

bijector property

Returns a distreqx.TreeMap bijector that applies each leaf constraint's bijector to the corresponding leaf of the input PyTree.

bounds property

Extracts a PyTree of lower bounds and a PyTree of upper bounds. Non-constraint nodes in the original PyTree are left unmodified.
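The leaf-wise logic of `TreeConstraint` (the emptiness check in `__init__` and the `bounds` property) can be approximated with `jax.tree` utilities. In this sketch, `Box` is a hypothetical stand-in for a constraint leaf and `is_box` plays the role of `parax.filters.is_constraint`; neither is Parax code. The check filters leaves explicitly because `jax.tree.leaves` also returns non-constraint leaves, and `is_leaf` only controls where traversal stops.

```python
import jax
import jax.numpy as jnp


class Box:
    """Hypothetical constraint leaf holding (lower, upper) bounds."""

    def __init__(self, lower, upper):
        self.lower = jnp.asarray(lower, dtype=float)
        self.upper = jnp.asarray(upper, dtype=float)


def is_box(node):
    return isinstance(node, Box)


def check_nonempty(tree):
    # Reject trees that contain no constraint leaves at all.
    leaves = jax.tree.leaves(tree, is_leaf=is_box)
    if not [leaf for leaf in leaves if is_box(leaf)]:
        raise ValueError("The pytree of constraints cannot be empty.")


def tree_bounds(tree):
    # Replace each Box with its bound; leave every other leaf unmodified.
    lower = jax.tree.map(lambda n: n.lower if is_box(n) else n, tree, is_leaf=is_box)
    upper = jax.tree.map(lambda n: n.upper if is_box(n) else n, tree, is_leaf=is_box)
    return lower, upper


constraints = {"rate": Box(0.0, jnp.inf), "offset": Box(-1.0, 1.0), "tag": 7}
check_nonempty(constraints)
lower, upper = tree_bounds(constraints)
print(lower["offset"], upper["rate"], lower["tag"])
```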

parax.constraints.CustomConstraint(bijector, bounds=(jnp.array(-jnp.inf), jnp.array(jnp.inf)))

Bases: AbstractConstraint

An escape hatch for power users who need a specific distreqx bijector mapping with predefined physical bounds.

Attributes:

_custom_bijector (AbstractBijector): The internal, user-defined distreqx bijector.

_custom_bounds (tuple[Array, Array]): The manually defined physical boundaries (lower, upper).

Parameters:

bijector (AbstractBijector, required): The custom distreqx bijector.

bounds (tuple[Array, Array], default (jnp.array(-jnp.inf), jnp.array(jnp.inf))): A tuple of (lower, upper) defining the physical boundaries of the constrained space.
Source code in parax/constraints.py
def __init__(
    self, 
    bijector: AbstractBijector, 
    bounds: tuple[Array, Array] = (jnp.array(-jnp.inf), jnp.array(jnp.inf))
):
    """
    Args:
        bijector: The custom `distreqx` bijector.
        bounds: A tuple of `(lower, upper)` defining the physical 
            boundaries of the constrained space. Defaults to `(-inf, inf)`.
    """
    self._custom_bijector = bijector
    self._custom_bounds = tuple(jnp.asarray(b) for b in bounds)
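The pattern behind `CustomConstraint` is to pair an invertible map with bounds that you state yourself. The sketch below does this in plain JAX with a tanh map onto (-1, 1); `tanh_pair` is a hypothetical stand-in for a distreqx bijector, and the bounds are normalized with the same `jnp.asarray` generator as the constructor above. The constructor shown does not check that the supplied bounds match the map's actual range, so it is worth verifying them, as done here.

```python
import jax.numpy as jnp

# Hypothetical stand-in for a distreqx bijector: tanh maps R onto (-1, 1).
tanh_pair = (jnp.tanh, jnp.arctanh)

# Same normalization that CustomConstraint.__init__ applies to `bounds`.
bounds = tuple(jnp.asarray(b) for b in (-1.0, 1.0))

x = jnp.linspace(-3.0, 3.0, 7)
y = tanh_pair[0](x)
lower, upper = bounds
print(bool(jnp.all((y > lower) & (y < upper))))  # True
```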