Bounds
parax.bounds.AbstractBounded
Bases: Module, Generic[T]
The abstract interface for a bounded PyTree.
Used as a type check for parax.is_bounded.
Attributes:
| Name | Type | Description |
|---|---|---|
| `bounds` | `AbstractVar[tuple[T, T]]` | Returns the current PyTree bounds. Each must have a matching PyTree structure as … |
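The interface can be illustrated with a small dependency-free sketch. It makes no claims about Parax's internals: `AbstractBoundedSketch`, `Interval`, and `is_bounded_sketch` are hypothetical names, and the sketch uses Python's `abc` module instead of Equinox's `Module` and `AbstractVar` so that it runs without JAX. It also assumes the `bounds` tuple is ordered `(lower, upper)`. A concrete subclass supplies that tuple, and an `isinstance` check plays the role that AbstractBounded plays for parax.is_bounded.

```python
from abc import ABC, abstractmethod


class AbstractBoundedSketch(ABC):
    """Hypothetical stand-in for parax.bounds.AbstractBounded (plain abc, not Equinox)."""

    @property
    @abstractmethod
    def bounds(self):
        """Return the (lower, upper) bounds of this node."""


class Interval(AbstractBoundedSketch):
    """Hypothetical concrete bounded scalar holding a value and its bounds."""

    def __init__(self, value, lower, upper):
        self.value = value
        self._bounds = (lower, upper)

    @property
    def bounds(self):
        return self._bounds


def is_bounded_sketch(x):
    """Type check in the spirit of parax.is_bounded: does x implement the interface?"""
    return isinstance(x, AbstractBoundedSketch)


p = Interval(0.5, 0.0, 1.0)
print(is_bounded_sketch(p), p.bounds)  # True (0.0, 1.0)
print(is_bounded_sketch(0.5))          # False
```

Because `bounds` is abstract, the base class itself cannot be instantiated; only subclasses that provide bounds pass the type check.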
parax.bounds.tree_lower(tree)
Extracts the lower bounds of a potentially bounded PyTree.
Standard arrays default to the bounds (-inf, inf), so their lower bound is -inf.
This function only accepts leaves that are arrays or instances of parax.bounds.AbstractBounded.
If your tree contains any other leaves, mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree` | The PyTree model to extract lower bounds from. | required |
Returns:
| Type | Description |
|---|---|
| `PyTree` | A PyTree representing the lower bounds. |
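A minimal sketch of the behaviour described above, assuming nested dicts, lists, and tuples in place of JAX PyTrees and plain Python floats in place of arrays. `Bounded` and `tree_lower_sketch` are hypothetical stand-ins, not Parax's implementation: bounded nodes contribute their own lower bound, float leaves default to -inf, and any other leaf raises.

```python
import math


class Bounded:
    """Hypothetical bounded node standing in for an AbstractBounded subclass."""

    def __init__(self, value, lower, upper):
        self.value = value
        self.bounds = (lower, upper)


def tree_lower_sketch(tree):
    """Lower bounds of a nested dict/list/tuple tree (floats stand in for arrays)."""
    if isinstance(tree, Bounded):
        return tree.bounds[0]  # bounded node: use its own lower bound
    if isinstance(tree, dict):
        return {k: tree_lower_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_lower_sketch(v) for v in tree)
    if isinstance(tree, float):
        return -math.inf  # plain array leaf: default bounds are (-inf, inf)
    raise TypeError(f"unsupported leaf {tree!r}: mark it static or filter it out first")


model = {"w": [0.1, 0.2], "scale": Bounded(1.0, 0.0, 10.0)}
print(tree_lower_sketch(model))  # {'w': [-inf, -inf], 'scale': 0.0}
```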
parax.bounds.tree_upper(tree)
Extracts the upper bounds of a potentially bounded PyTree.
Standard arrays default to the bounds (-inf, inf), so their upper bound is inf.
This function only accepts leaves that are arrays or instances of parax.bounds.AbstractBounded.
If your tree contains any other leaves, mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree` | The PyTree model to extract upper bounds from. | required |
Returns:
| Type | Description |
|---|---|
| `PyTree` | A PyTree representing the upper bounds. |
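The same traversal, mirrored for upper bounds, can also show why unsupported leaves must be removed first. This sketch again assumes nested dicts, lists, and tuples for PyTrees and floats for arrays; `Bounded`, `tree_upper_sketch`, and `keep_supported` are hypothetical. `keep_supported` imitates the effect of eqx.filter by replacing unsupported leaves with None, which the traversal then treats as an empty subtree.

```python
import math


class Bounded:
    """Hypothetical bounded node standing in for an AbstractBounded subclass."""

    def __init__(self, value, lower, upper):
        self.value = value
        self.bounds = (lower, upper)


def tree_upper_sketch(tree):
    """Upper bounds of a nested dict/list/tuple tree (floats stand in for arrays)."""
    if tree is None:
        return None  # None is an empty subtree, as left behind by filtering
    if isinstance(tree, Bounded):
        return tree.bounds[1]  # bounded node: use its own upper bound
    if isinstance(tree, dict):
        return {k: tree_upper_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_upper_sketch(v) for v in tree)
    if isinstance(tree, float):
        return math.inf  # plain array leaf: default bounds are (-inf, inf)
    raise TypeError(f"unsupported leaf {tree!r}: mark it static or filter it out first")


def keep_supported(tree):
    """Hypothetical filter in the spirit of eqx.filter: unsupported leaves become None."""
    if isinstance(tree, dict):
        return {k: keep_supported(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(keep_supported(v) for v in tree)
    return tree if isinstance(tree, (float, Bounded)) else None


model = {"w": 0.3, "name": "encoder", "temp": Bounded(1.0, 0.1, 5.0)}
try:
    tree_upper_sketch(model)
except TypeError as err:
    print("rejected:", err)
print(tree_upper_sketch(keep_supported(model)))  # {'w': inf, 'name': None, 'temp': 5.0}
```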
parax.bounds.tree_bounds(tree)
Extracts two PyTrees (lower and upper) representing the bounds of a potentially bounded PyTree.
Standard arrays default to the bounds (-inf, inf), so they contribute -inf to the lower tree and inf to the upper tree.
This function only accepts leaves that are arrays or instances of parax.bounds.AbstractBounded.
If your tree contains any other leaves, mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree` | The PyTree model to extract bounds from. | required |
Returns:
| Type | Description |
|---|---|
| `tuple[PyTree, PyTree]` | A tuple of two PyTrees, `(lower, upper)`. |
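A one-pass sketch of this function under the same assumptions: nested dicts, lists, and tuples stand in for PyTrees, floats stand in for arrays, and `Bounded` and `tree_bounds_sketch` are hypothetical names rather than Parax code. It returns a `(lower, upper)` pair whose members share the input's structure.

```python
import math


class Bounded:
    """Hypothetical bounded node standing in for an AbstractBounded subclass."""

    def __init__(self, value, lower, upper):
        self.value = value
        self.bounds = (lower, upper)


def tree_bounds_sketch(tree):
    """Return (lower, upper) trees for a nested dict/list/tuple tree in one pass."""
    if isinstance(tree, Bounded):
        return tree.bounds
    if isinstance(tree, dict):
        pairs = {k: tree_bounds_sketch(v) for k, v in tree.items()}
        lower = {k: lo for k, (lo, _) in pairs.items()}
        upper = {k: hi for k, (_, hi) in pairs.items()}
        return lower, upper
    if isinstance(tree, (list, tuple)):
        pairs = [tree_bounds_sketch(v) for v in tree]
        return type(tree)(lo for lo, _ in pairs), type(tree)(hi for _, hi in pairs)
    if isinstance(tree, float):
        return -math.inf, math.inf  # plain array leaf: unbounded by default
    raise TypeError(f"unsupported leaf {tree!r}: mark it static or filter it out first")


model = {"w": [0.1], "rate": Bounded(0.5, 0.0, 1.0)}
lower, upper = tree_bounds_sketch(model)
print(lower)  # {'w': [-inf], 'rate': 0.0}
print(upper)  # {'w': [inf], 'rate': 1.0}
```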
parax.bounds.is_dynamic(x)
Identifies parameters that should be updated during bounded inference.
This function acts as the primary filter for eqx.partition, determining
which nodes are routed to the dynamic (differentiable/optimizable) tree
and which are left behind in the static tree.
Because parax.probability.is_leaf stops traversal at unwrappable nodes instead of
splitting them open, this function receives those nodes whole and routes them,
intact, into the dynamic tree, where they can be safely unwrapped after partitioning.
If you instead want to pass the full, wrapped nodes through a jit boundary, add
further conditions or partitioning steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | | Any leaf node in the PyTree (as defined by …). | required |
Returns:
| Type | Description |
|---|---|
| `bool` | True if the node is meant for the inference engine. Matches: (1) standard JAX inexact arrays (floating-point tensors); (2) entire unwrappable bounded nodes. |
Note: Explicitly returns False for … them into the static tree.
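The partitioning pattern this function is meant for can be sketched without JAX. Here `partition` and `combine` are simplified, hypothetical analogues of eqx.partition and eqx.combine over nested dicts, lists, and tuples; floats stand in for inexact arrays, and `Bounded` is a hypothetical unwrappable node that this traversal never opens, so the predicate sees it whole and routes it to the dynamic side intact.

```python
class Bounded:
    """Hypothetical unwrappable bounded node (stand-in for a Parax bounded type)."""

    def __init__(self, value, lower, upper):
        self.value = value
        self.bounds = (lower, upper)


def is_dynamic_sketch(x):
    """Floats stand in for inexact arrays; whole Bounded nodes are dynamic too."""
    return isinstance(x, (float, Bounded))


def partition(tree, pred):
    """Split a nested dict/list/tuple tree into (dynamic, static); None marks holes."""
    if isinstance(tree, dict):
        parts = {k: partition(v, pred) for k, v in tree.items()}
        return ({k: d for k, (d, _) in parts.items()},
                {k: s for k, (_, s) in parts.items()})
    if isinstance(tree, (list, tuple)):
        parts = [partition(v, pred) for v in tree]
        return type(tree)(d for d, _ in parts), type(tree)(s for _, s in parts)
    # Any other object (including a whole Bounded node) is a leaf routed to one side.
    return (tree, None) if pred(tree) else (None, tree)


def combine(dynamic, static):
    """Inverse of partition: fill each None hole from the other tree."""
    if isinstance(dynamic, dict):
        return {k: combine(dynamic[k], static[k]) for k in dynamic}
    if isinstance(dynamic, (list, tuple)):
        return type(dynamic)(combine(d, s) for d, s in zip(dynamic, static))
    return static if dynamic is None else dynamic


model = {"w": 0.3, "steps": 10, "act": "relu", "temp": Bounded(1.0, 0.1, 5.0)}
dynamic, static = partition(model, is_dynamic_sketch)
print(dynamic["steps"], static["steps"])  # None 10
print(combine(dynamic, static) == model)  # True
```

Integers and strings land in the static tree, and combining the two halves restores the original structure.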
parax.bounds.is_leaf(x)
Defines the tree traversal boundaries for bounded partitioning.
In the Parax ecosystem, certain custom nodes (like unwrappable bounded nodes) contain internal metadata. If Equinox traverses inside these nodes, it will strip their differentiable arrays away from their metadata, causing structural mismatches during recombination.
This function tells JAX/Equinox to treat these specific Parax objects as opaque, indivisible leaves.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | | Any node encountered during PyTree traversal. | required |
Returns:
| Type | Description |
|---|---|
| `bool` | True if the node should NOT be traversed into. Matches: (1) unwrappable bounded nodes (preserves their wrapper structure); (2) Constant nodes (protects static configuration objects). |
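The effect of such a predicate can be sketched with a tiny flattening routine modelled on the `is_leaf` argument of `jax.tree_util.tree_leaves`. Everything here is hypothetical: `Bounded` exposes its value and bounds as traversable children, `Constant` wraps static configuration, and `is_leaf_sketch` stops traversal at both so they survive as whole leaves. Dicts, lists, and tuples stand in for PyTree containers.

```python
class Bounded:
    """Hypothetical unwrappable bounded node whose fields are traversable children."""

    def __init__(self, value, lower, upper):
        self.value, self.lower, self.upper = value, lower, upper


class Constant:
    """Hypothetical wrapper marking static configuration."""

    def __init__(self, value):
        self.value = value


def children(node):
    """Children of a container node, or None if the node is a plain leaf."""
    if isinstance(node, dict):
        return list(node.values())
    if isinstance(node, (list, tuple)):
        return list(node)
    if isinstance(node, Bounded):
        return [node.value, node.lower, node.upper]
    if isinstance(node, Constant):
        return [node.value]
    return None


def tree_leaves(tree, is_leaf=None):
    """Flatten a tree; any node for which is_leaf(node) is True is kept whole."""
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    kids = children(tree)
    if kids is None:
        return [tree]
    return [leaf for kid in kids for leaf in tree_leaves(kid, is_leaf)]


def is_leaf_sketch(x):
    """Treat Bounded and Constant nodes as opaque, indivisible leaves."""
    return isinstance(x, (Bounded, Constant))


model = {"w": 0.3, "temp": Bounded(1.0, 0.1, 5.0), "cfg": Constant("relu")}
print(tree_leaves(model))  # [0.3, 1.0, 0.1, 5.0, 'relu']
print(len(tree_leaves(model, is_leaf=is_leaf_sketch)))  # 3
```

Without the predicate, the bounded node dissolves into three loose floats and its wrapper is lost; with it, each wrapped node comes back intact.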