Constraints
parax.constraints.AbstractConstraint
Bases: Module
The base class for all physical constraints in Parax.
Constraints are a higher-level concept that provide bounds and bijectors over constrained domains. This makes them usable with both unconstrained solvers, which require a bijector from the unconstrained real line to the constrained domain, and bounded solvers, which accept lower and upper bounds directly.
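To show why a bijector lets an unconstrained solver respect a constraint, here is a minimal NumPy sketch (not Parax code). It fits a strictly positive parameter by running plain gradient descent on an unconstrained variable `u` and mapping it through `exp`. The `exp` bijector, the squared-error loss, the step size, and the function name `fit_positive` are all illustrative assumptions.

```python
import numpy as np

def fit_positive(target: float, steps: int = 500, lr: float = 0.05) -> float:
    """Minimise (x - target)**2 subject to x > 0, using x = exp(u) with u unconstrained."""
    u = 0.0  # unconstrained parameter: any real value is valid
    for _ in range(steps):
        x = np.exp(u)                # forward map R -> (0, inf)
        grad_x = 2.0 * (x - target)  # d loss / d x
        grad_u = grad_x * x          # chain rule: d x / d u = exp(u) = x
        u -= lr * grad_u             # unconstrained step; x remains positive
    return float(np.exp(u))
```

With a positive target the solver converges to it; with a negative target the result approaches 0 from above but never crosses it, because no value of `u` maps outside the constrained domain.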
Attributes:
| Name | Type | Description |
|---|---|---|
| bounds | AbstractVar[tuple[PyTree, PyTree]] | A tuple containing the physical lower and upper bounds of the constrained space. |
| bijector | AbstractVar[AbstractBijector] | A |
| base_bounds | AbstractVar[tuple[PyTree, PyTree]] | A tuple containing the foundational, un-skewed orthogonal bounds. For primitive constraints, this equals |
| base_bijector | AbstractVar[AbstractBijector] | A |
clip(value)
Clip a value to lie within this constraint.
Source code in parax/constraints.py
is_outside(value)
Returns whether a value lies outside the constraint.
Source code in parax/constraints.py
midpoint()
Returns the midpoint of the constraint.
Note that constraints with non-finite bounds may return infinity.
Source code in parax/constraints.py
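The three methods above can be pictured on a simple elementwise box. The sketch below is a hypothetical NumPy stand-in, not the Parax implementation; the class `Box` and its fields are illustrative only. It treats the bounds as closed, whereas whether Parax reports a value lying exactly on an exclusive bound as outside is not stated here.

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class Box:
    """Hypothetical elementwise bounds (lower, upper); either side may be infinite."""
    lower: np.ndarray
    upper: np.ndarray

    def clip(self, value):
        # Project each element onto [lower, upper].
        return np.clip(value, self.lower, self.upper)

    def is_outside(self, value):
        # True if any element lies strictly below lower or strictly above upper.
        return bool(np.any((value < self.lower) | (value > self.upper)))

    def midpoint(self):
        # Arithmetic midpoint; infinite when exactly one bound is infinite.
        return (self.lower + self.upper) / 2
```

For example, `Box(np.array(0.0), np.array(np.inf)).midpoint()` is `inf`, which matches the note that non-finite constraints may return infinity.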
parax.constraints.AbstractConstrained
Bases: AbstractBounded[T]
The abstract interface for a constrained PyTree.
Used as a type check for parax.is_constrained.
Implies that the PyTree has associated constraints (and therefore bounds), but does not necessarily enforce that the PyTree follows those constraints.
Attributes:
| Name | Type | Description |
|---|---|---|
| constraint | AbstractVar[AbstractConstraint] | Returns the active constraint of the PyTree. |
| bounds | AbstractVar[tuple[T, T]] | Returns the current PyTree bounds. Each must have a matching PyTree structure as |
parax.constraints.AbstractConstrainable
Bases: AbstractConstrained[T]
The abstract interface for a constrainable PyTree.
Variables implementing this interface support the dynamic injection and updating of constraints.
Used as a type check for parax.is_constrainable.
constrain(constraint)
abstractmethod
Returns a new instance of the PyTree with the updated constraint, ensuring internal state (like unconstrained raw values) is recalculated if necessary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| constraint | AbstractConstraint | The new constraint to apply. | required |

Returns:
| Type | Description |
|---|---|
| Self | A new instance of the constrainable PyTree. |
Source code in parax/constraints.py
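A minimal sketch of the pattern `constrain` describes: an immutable parameter that stores an unconstrained raw value and recomputes it whenever a new constraint is applied, so the physical value is preserved. `LowerBounded`, `Param`, and the softplus bijector are hypothetical choices for illustration, not Parax classes, and the sketch assumes the current value already satisfies the new constraint (otherwise the inverse is undefined).

```python
from dataclasses import dataclass, replace

import numpy as np

def softplus(u):
    return np.logaddexp(0.0, u)  # log(1 + exp(u)), numerically stable

def softplus_inv(y):
    return y + np.log(-np.expm1(-y))  # log(exp(y) - 1), defined for y > 0

@dataclass(frozen=True)
class LowerBounded:
    """Hypothetical constraint x > lower with bijector x = lower + softplus(u)."""
    lower: float

    def forward(self, u):
        return self.lower + softplus(u)

    def inverse(self, x):
        return softplus_inv(x - self.lower)

@dataclass(frozen=True)
class Param:
    """Hypothetical constrainable leaf that stores its raw (unconstrained) value."""
    raw: float
    constraint: LowerBounded

    @property
    def value(self):
        return self.constraint.forward(self.raw)

    def constrain(self, constraint):
        # Return a NEW instance; recompute raw so the physical value is unchanged.
        return replace(self, raw=constraint.inverse(self.value), constraint=constraint)
```

Returning a new instance rather than mutating in place fits the immutable-PyTree style that JAX transformations expect.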
parax.constraints.RealLine(shape=())
Bases: AbstractUncorrelatedConstraint
Represents a value that can span the entire real number line.
Effectively a structural no-op constraint using an Identity bijector, useful for maintaining consistent types in mixed parameter sets.
Attributes:
| Name | Type | Description |
|---|---|---|
| shape | Any | The expected shape of the unconstrained parameter. |

Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shape | Any | The expected shape of the unconstrained parameter. | () |
Source code in parax/constraints.py
parax.constraints.GreaterThan(lower)
Bases: AbstractUncorrelatedConstraint
Represents a value strictly greater than a lower bound.
Attributes:
| Name | Type | Description |
|---|---|---|
| lower | ndarray | The exclusive lower bound array or scalar. |

Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lower | Union[float, Array] | The exclusive lower bound. | required |
Source code in parax/constraints.py
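One common way to map the real line onto values strictly above a bound is `x = lower + exp(u)`. The NumPy sketch below assumes that bijector purely for illustration (the text does not say which bijector GreaterThan actually uses), and shows that an array lower bound broadcasts elementwise. The function names are hypothetical.

```python
import numpy as np

def greater_than_forward(u, lower):
    """Map unconstrained u to values strictly above `lower` (elementwise, broadcasting)."""
    # Note: in floating point, very negative u can round to exactly `lower`.
    return np.asarray(lower) + np.exp(u)

def greater_than_inverse(x, lower):
    """Inverse map; only defined for x > lower."""
    return np.log(np.asarray(x) - np.asarray(lower))
```

A scalar `lower` broadcasts against any shape of `u`, which is one reason a bound described as "array or scalar" is convenient.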
parax.constraints.LessThan(upper)
Bases: AbstractUncorrelatedConstraint
Represents a value strictly less than an upper bound.
Attributes:
| Name | Type | Description |
|---|---|---|
| upper | ndarray | The exclusive upper bound array or scalar. |

Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| upper | Union[float, Array] | The exclusive upper bound. | required |
Source code in parax/constraints.py
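An upper bound is the mirror image of a lower bound: `x < upper` holds exactly when `-x > -upper`. The sketch below reuses a GreaterThan-style map under reflection. The `exp`-based bijector and all function names are illustrative assumptions, not the Parax implementation.

```python
import numpy as np

def greater_than_forward(u, lower):
    # x = lower + exp(u), so x > lower
    return np.asarray(lower) + np.exp(u)

def less_than_forward(u, upper):
    # Reflect: x < upper <=> -x > -upper, so reuse the GreaterThan map and negate.
    # The result simplifies to upper - exp(u).
    return -greater_than_forward(u, -np.asarray(upper))

def less_than_inverse(x, upper):
    # Inverse of upper - exp(u); only defined for x < upper.
    return np.log(np.asarray(upper) - np.asarray(x))
```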
parax.constraints.Interval(lower, upper)
Bases: AbstractUncorrelatedConstraint
Represents a value strictly bounded between a lower and upper value.
Attributes:
| Name | Type | Description |
|---|---|---|
| lower | ndarray | The exclusive lower bound. |
| upper | ndarray | The exclusive upper bound. |

Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lower | Union[float, Array] | The exclusive lower bound. | required |
| upper | Union[float, Array] | The exclusive upper bound. | required |
Source code in parax/constraints.py
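A standard bijector from the real line onto an open interval is a scaled and shifted sigmoid, with the scaled logit as its inverse. This NumPy sketch assumes that choice for illustration only; Parax may use a different bijector, and the function names are hypothetical.

```python
import numpy as np

def sigmoid(u):
    return 0.5 * (1.0 + np.tanh(0.5 * u))  # stable form of 1 / (1 + exp(-u))

def interval_forward(u, lower, upper):
    """Map u in R to (lower, upper) via a scaled, shifted sigmoid."""
    # In floating point, very large |u| saturates to exactly lower or upper.
    lower, upper = np.asarray(lower), np.asarray(upper)
    return lower + (upper - lower) * sigmoid(u)

def interval_inverse(x, lower, upper):
    """Inverse (scaled logit); only defined for lower < x < upper."""
    p = (np.asarray(x) - lower) / (np.asarray(upper) - lower)
    return np.log(p) - np.log1p(-p)
```

At `u = 0` the map returns the midpoint of the interval, since `sigmoid(0) = 0.5`.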
parax.constraints.Positive(shape=(), dtype=None)
Bases: GreaterThan
Convenience constraint for values that must be strictly positive (> 0).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shape | Any | The shape of the parameter array. | () |
| dtype | Any | The JAX data type of the parameter array. | None |
Source code in parax/constraints.py
parax.constraints.Negative(shape=(), dtype=None)
Bases: LessThan
Convenience constraint for values that must be strictly negative (< 0).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shape | Any | The shape of the parameter array. | () |
| dtype | Any | The JAX data type of the parameter array. | None |
Source code in parax/constraints.py
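Since Positive and Negative subclass GreaterThan and LessThan, one plausible reading is that `shape` and `dtype` describe a zero bound array. The sketch below assumes exactly that, plus an `exp`-based bijector; both are illustrative assumptions, and the helper names are hypothetical.

```python
import numpy as np

def positive_forward(u, shape=(), dtype=None):
    # Hypothetical: Positive ~ GreaterThan(lower=zeros(shape, dtype)), x = lower + exp(u).
    lower = np.zeros(shape, dtype=dtype)
    return (lower + np.exp(u)).astype(lower.dtype)

def negative_forward(u, shape=(), dtype=None):
    # Hypothetical: Negative ~ LessThan(upper=zeros(shape, dtype)), x = upper - exp(u).
    upper = np.zeros(shape, dtype=dtype)
    return (upper - np.exp(u)).astype(upper.dtype)
```

With `dtype=None`, `np.zeros` defaults to `float64`; in JAX the default would typically be `float32` unless 64-bit mode is enabled.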
parax.constraints.Leafwise(tree)
Bases: AbstractConstraint
Represents a PyTree of constraints mapping over a PyTree of inputs.
Useful for applying heterogeneous constraints to complex nested structures (like equinox.Module instances) simultaneously.
Attributes:
| Name | Type | Description |
|---|---|---|
| tree | PyTree[AbstractConstraint] | The PyTree containing |
Source code in parax/constraints.py
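Leafwise pairs a PyTree of constraints with a PyTree of values that has the same structure. The sketch below illustrates that idea for clipping, using plain nested dicts and lists. `Bounds` and `leafwise_clip` are hypothetical names; in a JAX codebase one would typically use `jax.tree_util.tree_map` instead of the hand-written recursion.

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class Bounds:
    """Hypothetical leaf constraint: elementwise closed interval [lower, upper]."""
    lower: float
    upper: float

    def clip(self, x):
        return np.clip(x, self.lower, self.upper)

def leafwise_clip(constraints, values):
    """Apply each leaf constraint to the matching leaf of `values` (same nested structure)."""
    if isinstance(constraints, dict):
        return {k: leafwise_clip(constraints[k], values[k]) for k in constraints}
    if isinstance(constraints, (list, tuple)):
        return type(constraints)(leafwise_clip(c, v) for c, v in zip(constraints, values))
    return constraints.clip(values)  # leaf: a single constraint acts on a single array
```

For instance, `{"scale": Bounds(0.0, np.inf), "weights": [Bounds(-1.0, 1.0), Bounds(-np.inf, np.inf)]}` applies three different constraints to three differently shaped leaves in one call.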
parax.constraints.Custom(bijector, bounds=(jnp.array(-jnp.inf), jnp.array(jnp.inf)), base_bounds=None, base_bijector=None)
Bases: AbstractConstraint
An escape hatch for power users who need a specific distreqx bijector mapping with predefined physical bounds.
Attributes:
| Name | Type | Description |
|---|---|---|
| bijector | AbstractBijector | The internal, user-defined distreqx bijector mapping from the unconstrained real line to the physical space. |
| bounds | tuple[Array, Array] | The manually defined physical boundaries |
| base_bounds | tuple[Array, Array] | The orthogonal base boundaries. Defaults to |
| base_bijector | AbstractBijector | The bijector mapping from |

Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| bijector | AbstractBijector | The custom | required |
| bounds | tuple[Array, Array] | A tuple of | (array(-inf), array(inf)) |
| base_bounds | tuple[Array, Array] \| None | Optional. A tuple of | None |
| base_bijector | AbstractBijector \| None | Optional. The bijector handling spatial skew/correlation. If None, defaults to | None |
Source code in parax/constraints.py
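Because Custom takes its bounds as manually defined values alongside the bijector, keeping the two consistent is up to the user. The sketch below mimics that pairing with a hypothetical scaled-tanh bijector onto (-scale, scale). `ScaledTanh` and `CustomSketch` are illustrative stand-ins, not distreqx or Parax classes.

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class ScaledTanh:
    """Hypothetical bijector R -> (-scale, scale); not a distreqx class."""
    scale: float

    def forward(self, u):
        return self.scale * np.tanh(u)

    def inverse(self, x):
        return np.arctanh(np.asarray(x) / self.scale)

@dataclass(frozen=True)
class CustomSketch:
    """Hypothetical analogue of Custom: a user bijector plus hand-written bounds."""
    bijector: ScaledTanh
    bounds: tuple = (-np.inf, np.inf)

    def check(self, u):
        # Verify that the bijector's outputs land inside the declared bounds.
        # Closed comparison: tanh saturates to exactly +/-1 for large |u| in floating point.
        x = self.bijector.forward(u)
        lo, hi = self.bounds
        return bool(np.all((x >= lo) & (x <= hi)))
```

Leaving `bounds` at its infinite default is always consistent but discards information a bounded solver could use.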
parax.constraints.tree_constraints(tree)
Extracts the individual constraints of a PyTree.
Standard arrays default to parax.constraints.RealLine.
Note that this function does not accept leaves that are neither arrays nor constrainable. If your tree contains such leaves (anything that is not an array and does not derive from parax.constraints.AbstractConstrainable), mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model to extract constraints from. | required |

Returns:
| Type | Description |
|---|---|
| PyTree | A PyTree representing the active constraints. |
Source code in parax/constraints.py
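The extraction rules above (constrainable leaves report their own constraint, plain arrays default to an unbounded constraint, anything else is rejected) can be sketched with nested dicts and lists. Every class and function here is a hypothetical stand-in; the real function walks JAX PyTrees and returns Parax constraint objects.

```python
from dataclasses import dataclass

import numpy as np

@dataclass(frozen=True)
class RealLineSketch:
    """Hypothetical stand-in for an unbounded constraint of a given shape."""
    shape: tuple = ()

@dataclass(frozen=True)
class ConstrainedLeaf:
    """Hypothetical constrainable leaf carrying its own constraint."""
    value: np.ndarray
    constraint: object

def tree_constraints_sketch(tree):
    if isinstance(tree, dict):
        return {k: tree_constraints_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_constraints_sketch(v) for v in tree)
    if isinstance(tree, ConstrainedLeaf):
        return tree.constraint  # constrainable leaf: report its active constraint
    if isinstance(tree, np.ndarray):
        return RealLineSketch(tree.shape)  # plain array: unconstrained by default
    # Any other leaf (str, callable, ...) must be filtered out or marked static first.
    raise TypeError(f"unsupported leaf of type {type(tree).__name__}")
```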
parax.constraints.tree_leafwise_constraint(tree)
Extracts the single leafwise constraint of a PyTree.
Wraps the output of parax.constraints.tree_constraints in a parax.constraints.Leafwise constraint to define a single constraint that matches the shape of tree.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |

Returns:
| Type | Description |
|---|---|
| Leafwise | A single constraint whose shape matches the structure of |
Source code in parax/constraints.py
parax.constraints.tree_constrain(tree, constraints)
Applies a PyTree of constraints to a PyTree of constrainable PyTrees.
Standard arrays will be returned untouched if the matching constraint is a RealLine. Attempting to apply a bounded constraint directly to a standard array will raise an error.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model to update. Must have a matching PyTree structure to | required |
| constraints | PyTree | A PyTree of | required |

Returns:
| Type | Description |
|---|---|
| PyTree | A new PyTree with the constraints applied. |
Source code in parax/constraints.py
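The application rules above (constrainable leaves get a new instance with the new constraint, plain arrays pass through only under an unbounded constraint, and a bounded constraint on a plain array is an error) can be sketched as follows. All names are hypothetical stand-ins for illustration, and the sketch handles only nested dicts.

```python
from dataclasses import dataclass, replace

import numpy as np

@dataclass(frozen=True)
class RealLineSketch:
    """Hypothetical stand-in for an unbounded constraint."""

@dataclass(frozen=True)
class IntervalSketch:
    """Hypothetical stand-in for a bounded constraint."""
    lower: float
    upper: float

@dataclass(frozen=True)
class ConstrainedLeaf:
    value: np.ndarray
    constraint: object

    def constrain(self, constraint):
        return replace(self, constraint=constraint)  # new instance; original untouched

def tree_constrain_sketch(tree, constraints):
    if isinstance(tree, dict):
        return {k: tree_constrain_sketch(tree[k], constraints[k]) for k in tree}
    if isinstance(tree, ConstrainedLeaf):
        return tree.constrain(constraints)
    if isinstance(tree, np.ndarray):
        if isinstance(constraints, RealLineSketch):
            return tree  # plain array with an unbounded constraint: returned as-is
        raise TypeError("cannot attach a bounded constraint to a plain array")
    raise TypeError(f"unsupported leaf of type {type(tree).__name__}")
```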
parax.constraints.intersect(a, b)
Calculates the intersection of two constraints. Returns the most specific constraint class possible.
Source code in parax/constraints.py
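For simple one-dimensional constraints, an intersection keeps the tightest lower and the tightest upper bound, then picks the most specific class that can represent the result. The sketch below uses local scalar stand-ins that borrow the Parax class names for readability. They are not imports from Parax, and the sketch ignores skewed or Custom constraints entirely.

```python
import math
from dataclasses import dataclass

# Local scalar stand-ins named after the Parax classes (hypothetical, not imported).
@dataclass(frozen=True)
class RealLine:
    pass

@dataclass(frozen=True)
class GreaterThan:
    lower: float

@dataclass(frozen=True)
class LessThan:
    upper: float

@dataclass(frozen=True)
class Interval:
    lower: float
    upper: float

def bounds_of(c):
    """Return (lower, upper), using -inf/+inf for open sides."""
    return getattr(c, "lower", -math.inf), getattr(c, "upper", math.inf)

def from_bounds(lower, upper):
    """Pick the most specific stand-in class that represents (lower, upper)."""
    if lower > -math.inf and upper < math.inf:
        if lower >= upper:
            raise ValueError("empty intersection")
        return Interval(lower, upper)
    if lower > -math.inf:
        return GreaterThan(lower)
    if upper < math.inf:
        return LessThan(upper)
    return RealLine()

def intersect_sketch(a, b):
    lo_a, hi_a = bounds_of(a)
    lo_b, hi_b = bounds_of(b)
    # Intersection of two (possibly open-ended) intervals: tightest bound on each side.
    return from_bounds(max(lo_a, lo_b), min(hi_a, hi_b))
```

For example, intersecting `GreaterThan(0.0)` with `LessThan(1.0)` yields `Interval(0.0, 1.0)`, the narrowest class that describes both bounds.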