
Constraints

parax.constraints.AbstractConstraint

Bases: Module

The base class for all physical constraints in Parax.

Constraints are a higher-level abstraction that provides both bounds and bijectors for a constrained domain. This makes them usable with unconstrained solvers (which need a bijector from the unconstrained real line to the constrained domain) and with bounded solvers (which accept lower and upper bounds directly).
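As a rough illustration of the two solver styles, the sketch below minimises (x - 3)^2 subject to x > 0 both ways. The softplus bijector, step sizes, and iteration counts are assumptions chosen for the example, not parax internals: the unconstrained loop optimises a raw value u and maps it through the bijector, while the bounded loop updates x directly and projects each step back onto [lower, upper].

```python
import jax
import jax.numpy as jnp


# Assumed bijector for x > 0: softplus maps the whole real line onto (0, inf).
def forward(u):
    return jax.nn.softplus(u)


# Objective defined on the physical (constrained) value.
def objective(x):
    return (x - 3.0) ** 2


# Unconstrained solver: plain gradient descent on u; forward(u) is always > 0.
grad_u = jax.jit(jax.grad(lambda u: objective(forward(u))))
u = jnp.array(0.0)
for _ in range(500):
    u = u - 0.1 * grad_u(u)
x_unconstrained = forward(u)

# Bounded solver: gradient step on x, then projection onto [lower, upper].
lower, upper = 0.0, jnp.inf
grad_x = jax.jit(jax.grad(objective))
x_bounded = jnp.array(10.0)
for _ in range(200):
    x_bounded = jnp.clip(x_bounded - 0.1 * grad_x(x_bounded), lower, upper)

print(float(x_unconstrained), float(x_bounded))  # both approximately 3.0
```

The bijector route never leaves the feasible region, at the cost of reshaping the loss landscape; the bounded route keeps the original geometry but needs a solver that understands bounds.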

Attributes:

Name Type Description
bounds AbstractVar[tuple[PyTree, PyTree]]

A tuple containing the physical lower and upper bounds of the constrained space.

bijector AbstractVar[AbstractBijector]

A distreqx.bijectors.AbstractBijector mapping from the unconstrained real line to the physical space.

base_bounds AbstractVar[tuple[PyTree, PyTree]]

A tuple containing the foundational, unskewed, axis-aligned (orthogonal) bounds. For primitive constraints, this equals bounds. For transformed constraints, it describes the axis-aligned box that holds before any dense correlations or skews are applied.

base_bijector AbstractVar[AbstractBijector]

A distreqx.bijectors.AbstractBijector mapping from the orthogonal base_bounds space into the physical bounds space. Defaults to Identity unless geometric skews are present.

clip(value)

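Clips a value, leaf by leaf, to lie within this constraint's bounds. Values outside are clamped onto the nearest bound, so the result can equal an exclusive bound exactly.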

Source code in parax/constraints.py
def clip(self, value: PyTree) -> PyTree:
    """
    Clip a value to lie within this constraint.
    """
    return jax.tree.map(jnp.clip, value, self.bounds[0], self.bounds[1])
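A toy check of the same leafwise clipping, using a plain dict PyTree with made-up values and bounds (an infinite bound leaves that side untouched):

```python
import jax
import jax.numpy as jnp

# Hypothetical values and matching lower/upper bound PyTrees.
value = {"a": jnp.array([-2.0, 0.5, 4.0]), "b": jnp.array(7.0)}
lower = {"a": jnp.zeros(3), "b": jnp.array(-jnp.inf)}
upper = {"a": jnp.ones(3), "b": jnp.array(5.0)}

# Same pattern as `clip`: jnp.clip is applied leaf by leaf across the three trees.
clipped = jax.tree_util.tree_map(jnp.clip, value, lower, upper)
# clipped["a"] holds 0.0, 0.5, 1.0 and clipped["b"] holds 5.0
```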

is_outside(value)

Returns a PyTree of booleans indicating, leaf by leaf, whether value lies outside the constraint's bounds.

Source code in parax/constraints.py
def is_outside(self, value: PyTree) -> PyTree:
    """
    Returns if another value is outside the constraint.
    """
    lower, upper = self.bounds
    return jax.tree.map(lambda x, l, u: jnp.logical_or(x < l, x > u), value, lower, upper)
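The comparison is strict, so a value sitting exactly on a bound is not reported as outside, even for constraints whose bounds are documented as exclusive. The sketch below reproduces the same expression on a toy PyTree with hypothetical values:

```python
import jax
import jax.numpy as jnp

lower = {"x": jnp.array([0.0, 0.0, 0.0])}
upper = {"x": jnp.array([1.0, 1.0, 1.0])}
value = {"x": jnp.array([-0.1, 1.0, 1.5])}

# Same strict comparison as `is_outside`: only x < lower or x > upper counts.
outside = jax.tree_util.tree_map(
    lambda x, lo, hi: jnp.logical_or(x < lo, x > hi), value, lower, upper
)
# outside["x"] is [True, False, True]: 1.0 lies exactly on the upper bound.
```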

midpoint()

Returns the midpoint of the constraint.

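Note that for unbounded constraints the midpoint is not finite: it is ±inf when only one side is bounded, and NaN when neither side is (since -inf + inf is NaN).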

Source code in parax/constraints.py
def midpoint(self) -> PyTree:
    """
    Returns the midpoint of the constraint.

    Note that non-finite constraints may return infinity.
    """
    return jax.tree.map(lambda a, b: (a + b) / 2.0, self.bounds[0], self.bounds[1])
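A quick check of the same arithmetic on three hypothetical bound pairs (finite, half-bounded, and fully unbounded):

```python
import jax
import jax.numpy as jnp


def midpoint(lower, upper):
    # Same arithmetic as `midpoint`, applied leaf by leaf.
    return jax.tree_util.tree_map(lambda a, b: (a + b) / 2.0, lower, upper)


finite = midpoint(jnp.array(0.0), jnp.array(4.0))              # 2.0
half_bounded = midpoint(jnp.array(0.0), jnp.array(jnp.inf))     # inf
unbounded = midpoint(jnp.array(-jnp.inf), jnp.array(jnp.inf))   # nan
```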

parax.constraints.AbstractConstrained

Bases: AbstractBounded[T]

The abstract interface for a constrained PyTree.

Used as a type check for parax.is_constrained.

Implies that the PyTree has associated constraints (and therefore bounds), but does not necessarily enforce that the PyTree follows those constraints.

Attributes:

Name Type Description
constraint AbstractVar[AbstractConstraint]

Returns the active constraint of the PyTree.

bounds AbstractVar[tuple[T, T]]

Returns the current PyTree bounds. Each bound must have the same PyTree structure as self.

parax.constraints.AbstractConstrainable

Bases: AbstractConstrained[T]

The abstract interface for a constrainable PyTree.

Variables implementing this interface support the dynamic injection and updating of constraints.

Used as a type check for parax.is_constrainable.

constrain(constraint) abstractmethod

Returns a new instance of the PyTree with the updated constraint, ensuring internal state (like unconstrained raw values) is recalculated if necessary.

Parameters:

Name Type Description Default
constraint AbstractConstraint

The new constraint to apply.

required

Returns:

Type Description
Self

A new instance of the constrainable PyTree.

Source code in parax/constraints.py
@abstractmethod
def constrain(self, constraint: AbstractConstraint) -> Self:
    """
    Returns a new instance of the PyTree with the updated constraint,
    ensuring internal state (like unconstrained raw values) is 
    recalculated if necessary.

    Args:
        constraint: The new constraint to apply.

    Returns:
        A new instance of the constrainable PyTree.
    """
    raise NotImplementedError
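To make the contract concrete, here is a toy sketch of a constrainable leaf that stores an unconstrained raw value. ToyInterval and ToyConstrained are hypothetical classes written for this illustration (they are not parax types), and the sigmoid bijector is an assumption. The point is the recalculation step: constrain inverts the new constraint at the current physical value, so the value is preserved while raw changes.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class ToyInterval:
    """Hypothetical interval constraint (not a parax class)."""

    lower: float
    upper: float

    def forward(self, raw):
        # Unconstrained real line -> (lower, upper) via a scaled sigmoid.
        return self.lower + (self.upper - self.lower) * jax.nn.sigmoid(raw)

    def inverse(self, value):
        # (lower, upper) -> real line via the logit of the rescaled value.
        p = (value - self.lower) / (self.upper - self.lower)
        return jnp.log(p) - jnp.log1p(-p)


@dataclass(frozen=True)
class ToyConstrained:
    """Hypothetical constrainable leaf storing an unconstrained raw value."""

    raw: jax.Array
    constraint: ToyInterval

    @property
    def value(self):
        return self.constraint.forward(self.raw)

    def constrain(self, constraint):
        # Recompute raw so the physical value is unchanged under the new constraint.
        return ToyConstrained(raw=constraint.inverse(self.value), constraint=constraint)


p = ToyConstrained(raw=jnp.array(0.0), constraint=ToyInterval(0.0, 2.0))
q = p.constrain(ToyInterval(0.0, 10.0))
print(float(p.value), float(q.value))  # both approximately 1.0
```

This only preserves the value when it already lies inside the new constraint; otherwise the toy inverse returns nan.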

parax.constraints.RealLine(shape=())

Bases: AbstractUncorrelatedConstraint

Represents a value that can span the entire real number line.

Effectively a structural no-op constraint using an Identity bijector, useful for maintaining consistent types in mixed parameter sets.

Attributes:

Name Type Description
shape Any

The expected shape of the unconstrained parameter.

Parameters:

Name Type Description Default
shape Any

The expected shape of the unconstrained parameter.

()
Source code in parax/constraints.py
def __init__(self, shape: Any = ()):
    """
    Args:
        shape: The expected shape of the unconstrained parameter.
    """
    self.shape = shape

parax.constraints.GreaterThan(lower)

Bases: AbstractUncorrelatedConstraint

Represents a value strictly greater than a lower bound.

Attributes:

Name Type Description
lower ndarray

The exclusive lower bound array or scalar.

Parameters:

Name Type Description Default
lower Union[float, Array]

The exclusive lower bound.

required
Source code in parax/constraints.py
def __init__(self, lower: Union[float, Array]):
    """
    Args:
        lower: The exclusive lower bound.
    """
    self.lower = jnp.asarray(lower, dtype=float)
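One common way to build a bijector onto (lower, inf) is a shifted softplus. The sketch below assumes that choice purely for illustration (parax may use a different map, such as exp) and checks the round trip on a few points:

```python
import jax
import jax.numpy as jnp

lower = jnp.asarray(1.5, dtype=float)


def forward(u):
    # Real line -> (lower, inf) via a shifted softplus (assumed choice).
    return lower + jax.nn.softplus(u)


def inverse(x):
    # Inverse of softplus is log(exp(y) - 1), computed stably as log(expm1(y)).
    return jnp.log(jnp.expm1(x - lower))


u = jnp.linspace(-5.0, 5.0, 11)
x = forward(u)
```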

parax.constraints.LessThan(upper)

Bases: AbstractUncorrelatedConstraint

Represents a value strictly less than an upper bound.

Attributes:

Name Type Description
upper ndarray

The exclusive upper bound array or scalar.

Parameters:

Name Type Description Default
upper Union[float, Array]

The exclusive upper bound.

required
Source code in parax/constraints.py
def __init__(self, upper: Union[float, Array]):
    """
    Args:
        upper: The exclusive upper bound.
    """
    self.upper = jnp.asarray(upper, dtype=float)

parax.constraints.Interval(lower, upper)

Bases: AbstractUncorrelatedConstraint

Represents a value strictly bounded between a lower and upper value.

Attributes:

Name Type Description
lower ndarray

The exclusive lower bound.

upper ndarray

The exclusive upper bound.

Parameters:

Name Type Description Default
lower Union[float, Array]

The exclusive lower bound.

required
upper Union[float, Array]

The exclusive upper bound.

required
Source code in parax/constraints.py
def __init__(self, lower: Union[float, Array], upper: Union[float, Array]):
    """
    Args:
        lower: The exclusive lower bound.
        upper: The exclusive upper bound.
    """
    self.lower = jnp.asarray(lower, dtype=float)
    self.upper = jnp.asarray(upper, dtype=float)
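A scaled and shifted sigmoid is a standard bijector from the real line onto (lower, upper). The sketch assumes that choice for illustration only; the actual bijector parax pairs with Interval is not shown on this page.

```python
import jax
import jax.numpy as jnp

lower = jnp.asarray(-1.0, dtype=float)
upper = jnp.asarray(3.0, dtype=float)


def forward(u):
    # Real line -> (lower, upper) via a scaled and shifted sigmoid (assumed choice).
    return lower + (upper - lower) * jax.nn.sigmoid(u)


def inverse(x):
    # Logit of the value rescaled to (0, 1).
    p = (x - lower) / (upper - lower)
    return jnp.log(p) - jnp.log1p(-p)


u = jnp.linspace(-4.0, 4.0, 9)
x = forward(u)
```

The bounds are exclusive mathematically, but in floating point a large |u| can still round x to exactly lower or upper (for example, u = 50 in float32), so downstream code should not rely on strict inequality.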

parax.constraints.Positive(shape=(), dtype=None)

Bases: GreaterThan

Convenience constraint for values that must be strictly positive (> 0).

Parameters:

Name Type Description Default
shape Any

The shape of the parameter array.

()
dtype Any

The JAX data type of the parameter array.

None
Source code in parax/constraints.py
def __init__(self, shape: Any = (), dtype: Any = None):
    """
    Args:
        shape: The shape of the parameter array.
        dtype: The JAX data type of the parameter array.
    """
    super().__init__(lower=jnp.zeros(shape, dtype=dtype))
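Going by the source above, Positive passes jnp.zeros(shape, dtype=dtype) to GreaterThan, whose constructor re-casts the bound with jnp.asarray(lower, dtype=float). The sketch below repeats those two steps with plain jax.numpy to show the effect: the shape survives, but the bound ends up in JAX's default float dtype rather than the requested one.

```python
import jax.numpy as jnp

# Step 1, as in Positive.__init__: zeros with the requested shape and dtype.
zeros = jnp.zeros((2, 3), dtype=jnp.float16)

# Step 2, as in GreaterThan.__init__: re-cast with dtype=float.
lower = jnp.asarray(zeros, dtype=float)

print(lower.shape)  # (2, 3)
print(lower.dtype)  # JAX's default float dtype (float32 unless x64 is enabled)
```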

parax.constraints.Negative(shape=(), dtype=None)

Bases: LessThan

Convenience constraint for values that must be strictly negative (< 0).

Parameters:

Name Type Description Default
shape Any

The shape of the parameter array.

()
dtype Any

The JAX data type of the parameter array.

None
Source code in parax/constraints.py
def __init__(self, shape: Any = (), dtype: Any = None):
    """
    Args:
        shape: The shape of the parameter array.
        dtype: The JAX data type of the parameter array.
    """
    super().__init__(upper=jnp.zeros(shape, dtype=dtype))

parax.constraints.Leafwise(tree)

Bases: AbstractConstraint

Represents a PyTree of constraints mapping over a PyTree of inputs.

Useful for applying heterogeneous constraints to complex nested structures (like equinox.Module instances) simultaneously.

Attributes:

Name Type Description
tree PyTree[AbstractConstraint]

The PyTree containing AbstractConstraint leaves.

Source code in parax/constraints.py
def __init__(
    self, 
    tree: PyTree[AbstractConstraint],
):
    leaves = jax.tree.leaves(tree, is_leaf=is_constraint)
    if not leaves:
        raise ValueError("The pytree of `tree` cannot be empty.")
    self.tree = tree
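Mapping a PyTree of constraints over a PyTree of values usually needs an is_leaf predicate: since AbstractConstraint derives from Module, constraint objects are presumably PyTrees themselves, and the Leafwise constructor passes is_leaf=is_constraint for that reason. The sketch uses a hypothetical Bounds NamedTuple (also a PyTree) in place of parax constraints:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Bounds(NamedTuple):
    """Hypothetical per-leaf constraint record (not a parax class)."""

    lower: jax.Array
    upper: jax.Array


# A PyTree of constraints mirroring a PyTree of parameters.
constraints = {
    "scale": Bounds(jnp.array(0.0), jnp.array(jnp.inf)),
    "weights": Bounds(jnp.zeros(3), jnp.ones(3)),
}
params = {"scale": jnp.array(-2.0), "weights": jnp.array([0.2, 1.7, -0.3])}

# Without is_leaf, tree_map would descend into each Bounds and treat its
# fields as separate leaves; with it, each Bounds record is handled whole.
clipped = jax.tree_util.tree_map(
    lambda c, x: jnp.clip(x, c.lower, c.upper),
    constraints,
    params,
    is_leaf=lambda node: isinstance(node, Bounds),
)
```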

parax.constraints.Custom(bijector, bounds=(jnp.array(-jnp.inf), jnp.array(jnp.inf)), base_bounds=None, base_bijector=None)

Bases: AbstractConstraint

An escape hatch for power users who need a specific distreqx bijector mapping with predefined physical bounds.

Attributes:

Name Type Description
bijector AbstractBijector

The internal, user-defined distreqx bijector mapping from the unconstrained real line to the physical space.

bounds tuple[Array, Array]

The manually defined physical boundaries (lower, upper).

base_bounds tuple[Array, Array]

The orthogonal base boundaries. Defaults to bounds if omitted.

base_bijector AbstractBijector

The bijector mapping from base_bounds to bounds. Defaults to Identity if omitted.

Parameters:

Name Type Description Default
bijector AbstractBijector

The custom distreqx bijector.

required
bounds tuple[Array, Array]

A tuple of (lower, upper) defining the physical boundaries of the constrained space. Defaults to (-inf, inf).

(array(-inf), array(inf))
base_bounds tuple[Array, Array] | None

Optional. A tuple of (lower, upper) defining the orthogonal base boundaries. If None, defaults to bounds.

None
base_bijector AbstractBijector | None

Optional. The bijector handling spatial skew/correlation. If None, defaults to distreqx.bijectors.Identity.

None
Source code in parax/constraints.py
def __init__(
    self, 
    bijector: AbstractBijector, 
    bounds: tuple[Array, Array] = (jnp.array(-jnp.inf), jnp.array(jnp.inf)),
    base_bounds: tuple[Array, Array] | None = None,
    base_bijector: AbstractBijector | None = None
):
    """
    Args:
        bijector: The custom `distreqx` bijector.
        bounds: A tuple of `(lower, upper)` defining the physical 
            boundaries of the constrained space. Defaults to `(-inf, inf)`.
        base_bounds: Optional. A tuple of `(lower, upper)` defining the orthogonal 
            base boundaries. If None, defaults to `bounds`.
        base_bijector: Optional. The bijector handling spatial skew/correlation. 
            If None, defaults to `distreqx.bijectors.Identity`.
    """
    self.bijector = bijector
    self.bounds = tuple(jnp.asarray(b) for b in bounds)

    # Default base_bounds to physical bounds if not provided
    if base_bounds is None:
        self.base_bounds = self.bounds
    else:
        self.base_bounds = tuple(jnp.asarray(b) for b in base_bounds)

    # Default base_bijector to Identity if not provided
    if base_bijector is None:
        self.base_bijector = Identity()
    else:
        self.base_bijector = base_bijector
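To illustrate how base_bounds and bounds relate when a skew is present, the sketch below takes a hypothetical 2-D base box and a linear map standing in for base_bijector, then computes the tightest axis-aligned physical box containing the skewed region by splitting the matrix into its positive and negative parts. This is a generic interval-arithmetic technique, not necessarily how parax derives bounds. The physical box also contains points outside the skewed region, which is why keeping the base box separately is useful.

```python
import itertools

import jax.numpy as jnp

# Hypothetical orthogonal base box (the role of base_bounds) in 2-D.
base_lower = jnp.array([0.0, 0.0])
base_upper = jnp.array([1.0, 2.0])

# Hypothetical linear skew standing in for a base_bijector.
A = jnp.array([[1.0, 0.0],
               [-0.5, 1.0]])


def skew(z):
    return A @ z


# Tightest axis-aligned box containing the skewed region: positive entries of A
# pick the matching base bound, negative entries pick the opposite one.
A_pos = jnp.maximum(A, 0.0)
A_neg = jnp.minimum(A, 0.0)
phys_lower = A_pos @ base_lower + A_neg @ base_upper  # [0.0, -0.5]
phys_upper = A_pos @ base_upper + A_neg @ base_lower  # [1.0, 2.0]

# Images of the four corners of the base box, for comparison.
corners = itertools.product(*zip(base_lower.tolist(), base_upper.tolist()))
images = jnp.stack([skew(jnp.array(c)) for c in corners])
```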

parax.constraints.tree_constraints(tree)

Extracts the individual constraints of a PyTree.

Standard floating-point (inexact) arrays default to parax.constraints.RealLine.

Note that this function rejects leaf nodes that are neither arrays nor constrainable. If your tree contains leaves that are neither arrays nor derived from parax.constraints.AbstractConstrainable, mark them as static or filter them out first, e.g. with eqx.filter.

Parameters:

Name Type Description Default
tree PyTree

The PyTree model to extract constraints from.

required

Returns:

Type Description
PyTree

A PyTree representing the active constraints.

Source code in parax/constraints.py
def tree_constraints(tree: PyTree) -> PyTree:
    """
    Extracts the individual constraints of a PyTree.

    Standard arrays default to `parax.constraints.RealLine`.

    Note that this function does not allow non-array/constrainable leaf nodes.
    If you have leaves in your tree that are neither arrays nor derive
    from `parax.constraints.AbstractConstrainable`, be sure to mark
    them as static or filter them out using e.g. `eqx.filter` first.    

    Args:
        tree: The PyTree model to extract constraints from.

    Returns:
        A PyTree representing the active constraints.
    """
    from parax.wrappers import unwrap

    def _get_constraint(x):
        if is_constrained(x):
            return unwrap(x.constraint)
        if eqx.is_inexact_array(x):
            return RealLine(shape=x.shape)
        raise ValueError(
            f"Found a leaf node of type {type(x)} that is neither constrained "
            f"nor an array in `parax.constraints.tree_constraints`. Value: {x}"
        )

    return jax.tree_util.tree_map(_get_constraint, tree, is_leaf=is_constrained)
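The fallback branch only accepts inexact (floating-point or complex) arrays, so integer arrays such as index buffers, and Python scalars, raise an error. The sketch below approximates that branch with plain JAX; default_constraint and the ("real_line", shape) tuple are hypothetical stand-ins for the parax internals and RealLine.

```python
import jax
import jax.numpy as jnp


def default_constraint(x):
    # Rough approximation of the non-constrained branch of `_get_constraint`
    # (which uses eqx.is_inexact_array): only floating-point or complex arrays pass.
    if hasattr(x, "dtype") and jnp.issubdtype(x.dtype, jnp.inexact):
        # Hypothetical stand-in for RealLine(shape=x.shape).
        return ("real_line", x.shape)
    raise ValueError(f"Leaf of type {type(x)} is neither constrained nor an inexact array")


params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
constraints = jax.tree_util.tree_map(default_constraint, params)
# constraints == {"b": ("real_line", (3,)), "w": ("real_line", (2, 3))}

# An integer array is rejected, just like a non-array leaf.
try:
    jax.tree_util.tree_map(default_constraint, {"idx": jnp.arange(3)})
    rejected = False
except ValueError:
    rejected = True
```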

parax.constraints.tree_leafwise_constraint(tree)

Extracts the single leafwise constraint of a PyTree.

Wraps the output of parax.constraints.tree_constraints in a parax.constraints.Leafwise constraint to define a single constraint that matches the shape of tree.

Parameters:

Name Type Description Default
tree PyTree

The PyTree model containing constrained nodes or standard arrays.

required

Returns:

Type Description
Leafwise

A single constraint whose shape matches the structure of tree.

Source code in parax/constraints.py
def tree_leafwise_constraint(tree: PyTree) -> Leafwise:
    """
    Extracts the single leafwise constraint of a PyTree.

    Wraps the output of `parax.constraints.tree_constraints`
    in a `parax.constraints.Leafwise` constraint to define
    a single constraint that matches the shape of `tree`.

    Args:
        tree: The PyTree model containing probabilistic nodes or standard arrays.

    Returns:
        A single constraint whose shape matches the structure of `tree`.
    """
    return Leafwise(tree_constraints(tree)) 

parax.constraints.tree_constrain(tree, constraints)

Applies a PyTree of constraints to a PyTree of constrainable PyTrees.

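Standard floating-point arrays are returned unchanged if the matching constraint is a RealLine. Applying any other constraint directly to a raw array raises a TypeError; wrap the array in a parax.Constrained variable first. Leaves that are neither constrainable nor arrays raise a ValueError.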

Parameters:

Name Type Description Default
tree PyTree

The PyTree model to update. Must have the same PyTree structure as constraints.

required
constraints PyTree

A PyTree of parax.AbstractConstraint objects.

required

Returns:

Type Description
PyTree

A new PyTree with the constraints applied.

Source code in parax/constraints.py
def tree_constrain(tree: PyTree, constraints: PyTree) -> PyTree:
    """
    Applies a PyTree of constraints to a PyTree of constrainable PyTrees.

    Standard arrays will be returned untouched if the matching constraint 
    is a `RealLine`. Attempting to apply a bounded constraint directly 
    to a standard array will raise an error.

    Args:
        tree: The PyTree model to update. Must have a matching PyTree structure 
            to `constraints`.
        constraints: A PyTree of `parax.AbstractConstraint` objects.

    Returns:
        A new PyTree with the constraints applied.
    """
    def _apply_constraint(x, c):
        if is_constrainable(x):
            return x.constrain(c)
        if eqx.is_inexact_array(x):
            if isinstance(c, RealLine):
                return x
            raise TypeError(
                "Cannot apply a bounded constraint to a raw JAX array directly. "
                "Ensure the array is wrapped in a `parax.Constrained` variable first."
            )
        raise ValueError(
            f"Found a leaf node of type {type(x)} that is neither constrainable "
            f"nor an array in `parax.constraints.tree_constrain`. Value: {x}"
        )

    return jax.tree_util.tree_map(
        _apply_constraint, tree, constraints, is_leaf=is_constrainable
    )
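jax.tree_util.tree_map follows the structure of its first tree and passes the matching subtrees of the other trees whole. That is what lets tree_constrain pair each leaf of tree with a full constraint object, even though constraints may be PyTrees in their own right. The sketch shows the effect with a hypothetical Bounds NamedTuple standing in for a parax constraint:

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Bounds(NamedTuple):
    """Hypothetical constraint record; a NamedTuple is itself a PyTree."""

    lower: jax.Array
    upper: jax.Array


values = {"a": jnp.array(1.5), "b": jnp.array([-1.0, 2.0])}
constraints = {
    "a": Bounds(jnp.array(0.0), jnp.array(1.0)),
    "b": Bounds(jnp.zeros(2), jnp.full(2, 5.0)),
}

received = []


def apply(x, c):
    # x is an array leaf of `values`; c is the whole Bounds record at that position.
    received.append(type(c).__name__)
    return jnp.clip(x, c.lower, c.upper)


out = jax.tree_util.tree_map(apply, values, constraints)
```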

parax.constraints.intersect(a, b)

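Computes the intersection of two constraints from their bounds and returns the most specific constraint class that represents it (RealLine, Positive, Negative, GreaterThan, LessThan, or Interval). An empty or degenerate intersection (lower >= upper for any element) raises an error via eqx.error_if.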

Source code in parax/constraints.py
def intersect(a: AbstractConstraint, b: AbstractConstraint) -> AbstractConstraint:
    """
    Calculates the intersection of two constraints.
    Returns the most specific constraint class possible.
    """
    a_lower, a_upper = a.bounds
    b_lower, b_upper = b.bounds

    lower = jnp.maximum(a_lower, b_lower)
    upper = jnp.minimum(a_upper, b_upper)

    # Ensure JAX arrays; the Python-level boolean checks below need concrete (non-traced) values
    np_lower = jnp.asarray(lower)
    np_upper = jnp.asarray(upper)

    np_lower, np_upper = eqx.error_if(
        (np_lower, np_upper),
        jnp.any(jnp.greater_equal(np_lower, np_upper)),
        f"Constraint intersection is empty or invalid."
    )

    is_neginf_lower = jnp.all(jnp.isneginf(np_lower))
    is_posinf_upper = jnp.all(jnp.isposinf(np_upper))
    is_zero_lower = jnp.all(jnp.equal(np_lower, 0.0))
    is_zero_upper = jnp.all(jnp.equal(np_upper, 0.0))

    # Resolve to the most specific constraint class
    if is_neginf_lower and is_posinf_upper:
        return RealLine()
    elif is_zero_lower and is_posinf_upper:
        return Positive()
    elif is_neginf_lower and is_zero_upper:
        return Negative()
    elif is_posinf_upper:
        return GreaterThan(lower)
    elif is_neginf_lower:
        return LessThan(upper)
    else:
        return Interval(lower, upper)
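Because intersect chooses its return class with Python if statements on array values, it presumably has to run on concrete bounds rather than inside jax.jit or other transformations. The sketch below reproduces that pattern in simplified form; intersect_bounds is a hypothetical helper, not part of parax.

```python
import jax
import jax.numpy as jnp


def intersect_bounds(a, b):
    # Hypothetical, simplified version of `intersect` that returns a class name
    # instead of a parax constraint.
    lower = jnp.maximum(a[0], b[0])
    upper = jnp.minimum(a[1], b[1])
    # Python-level branching on array values, as in `intersect`.
    if jnp.all(jnp.isposinf(upper)):
        return "GreaterThan", lower, upper
    if jnp.all(jnp.isneginf(lower)):
        return "LessThan", lower, upper
    return "Interval", lower, upper


positive = (jnp.array(0.0), jnp.array(jnp.inf))
below_five = (jnp.array(-jnp.inf), jnp.array(5.0))

# Eager call with concrete bounds: works.
kind, lower, upper = intersect_bounds(positive, below_five)
print(kind, float(lower), float(upper))  # Interval 0.0 5.0

# Traced call: the `if` needs a concrete boolean, so tracing fails.
try:
    jax.jit(intersect_bounds)(positive, below_five)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
```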