Wrappers
parax.AbstractUnwrappable
Bases: Module, Generic[T]
An abstract class representing a deferred or wrapped PyTree node.
Unwrappables act as placeholders within a PyTree. When parax.unwrap
is called on the tree, these nodes execute custom logic (such as delayed
computation, parameter injection, or gradient stopping) and replace
themselves with their output.
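As a rough illustration of the placeholder idea described above, the following plain-Python sketch mimics the pattern without importing parax. `ToyUnwrappable`, `Doubled`, and `toy_unwrap` are hypothetical names invented for this example. The real classes are Equinox modules and the real traversal operates on JAX PyTrees, so treat this only as a mental model.

```python
from abc import ABC, abstractmethod


class ToyUnwrappable(ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable (not the real class)."""

    @abstractmethod
    def unwrap(self):
        """Return the value this placeholder stands for."""


class Doubled(ToyUnwrappable):
    """Placeholder that resolves to twice its stored value (a delayed computation)."""

    def __init__(self, value):
        self.value = value

    def unwrap(self):
        return 2 * self.value


def toy_unwrap(tree):
    """Replace every placeholder found in nested dicts/lists with its output."""
    if isinstance(tree, dict):
        return {k: toy_unwrap(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_unwrap(v) for v in tree]
    if isinstance(tree, ToyUnwrappable):
        return tree.unwrap()
    return tree


params = {"scale": Doubled(1.5), "bias": [0.0, Doubled(2)]}
resolved = toy_unwrap(params)
print(resolved)  # {'scale': 3.0, 'bias': [0.0, 4]}
```

The original `params` container is left untouched; the sketch builds a new structure, which mirrors the description of nodes being replaced by their output.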
unwrap()
abstractmethod
Evaluates and returns the underlying unwrapped PyTree node.
Returns:
| Type | Description |
|---|---|
| T | The resolved underlying value, assuming no wrapped subnodes exist. |
Source code in parax/wrappers.py
parax.AbstractWrappable
Bases: Module, Generic[T]
An abstract class representing a PyTree node capable of wrapping another tree.
This interface is the counterpart to parax.AbstractUnwrappable. It is
typically used to define how an object should reconstruct or "re-wrap" a
tree that was previously unwrapped.
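The unwrap/re-wrap relationship described above can be pictured with a small plain-Python sketch. `Tagged` is a hypothetical class made up for this example and is not part of parax; it only shows a node whose `wrap` returns a new instance of the same kind around a new value.

```python
class Tagged:
    """Hypothetical node offering both unwrap() and wrap(); not a parax class."""

    def __init__(self, tree, tag):
        self.tree = tree
        self.tag = tag

    def unwrap(self):
        # Expose the raw value held by this node.
        return self.tree

    def wrap(self, tree):
        # Build a new node of the same kind around the given value,
        # carrying over this node's metadata.
        return Tagged(tree, self.tag)


node = Tagged([1.0, 2.0], tag="weights")
raw = node.unwrap()
updated = [x + 1.0 for x in raw]  # work on the raw value
rewrapped = node.wrap(updated)    # restore the wrapper around the new value
print(rewrapped.tree, rewrapped.tag)  # [2.0, 3.0] weights
```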
wrap(tree)
abstractmethod
Wraps the provided tree inside this node's structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | T | The unwrapped tree or raw value to be wrapped. | required |

Returns:
| Type | Description |
|---|---|
| Self | A new instance of this node containing the wrapped tree. |
parax.unwrap(tree, only_if=None, cascade=True)
Recursively resolves AbstractUnwrappable nodes within a PyTree.
By default, unwrapping is performed inside-out (bottom-up) across the entire
tree. Every AbstractUnwrappable node is replaced by the result of its
unwrap() method.
If the only_if predicate is provided, unwrapping is conditionally gated.
The cascade parameter controls how descendants of matching nodes are handled.
Behavior with only_if:
1. If cascade=True (default): The tree is searched top-down. Once a
node satisfies only_if, that node and ALL of its AbstractUnwrappable
descendants are fully resolved.
2. If cascade=False: The tree is traversed bottom-up. Unwrapping triggers
ONLY for nodes that satisfy the condition. Non-matching descendants are
left wrapped.
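The gating rules above can be sketched in plain Python. This is only a toy model written from the description, not parax's implementation: `Box` and `toy_unwrap` are hypothetical names, the "tree" is a nested dict, and the predicate is only ever called on `Box` nodes.

```python
class Box:
    """Hypothetical placeholder holding a subtree; unwrap() returns it."""

    def __init__(self, tree, label):
        self.tree = tree
        self.label = label

    def unwrap(self):
        return self.tree


def toy_unwrap(tree, only_if=None, cascade=True):
    """Toy version of the only_if / cascade rules (not parax's code)."""
    if isinstance(tree, dict):
        return {k: toy_unwrap(v, only_if, cascade) for k, v in tree.items()}
    if not isinstance(tree, Box):
        return tree
    matched = only_if is None or only_if(tree)
    if matched and cascade:
        # Resolve this node and every Box below it, ignoring only_if from here on.
        return Box(toy_unwrap(tree.tree), tree.label).unwrap()
    # Otherwise keep applying the predicate to the contents first.
    inner = toy_unwrap(tree.tree, only_if, cascade)
    rebuilt = Box(inner, tree.label)
    return rebuilt.unwrap() if matched else rebuilt


def is_outer(node):
    return node.label == "outer"


tree = {"a": Box(Box(1, "inner"), "outer"), "b": Box(2, "other")}

full = toy_unwrap(tree)                                    # everything resolved
cascaded = toy_unwrap(tree, only_if=is_outer)              # "a" fully resolved
gated = toy_unwrap(tree, only_if=is_outer, cascade=False)  # only "outer" removed

print(full)                # {'a': 1, 'b': 2}
print(cascaded["a"])       # 1
print(gated["a"].label)    # inner
```

In all three calls the non-matching `"b"` entry is only resolved when no predicate is given; with `cascade=False` the inner box under `"a"` survives because it does not satisfy the predicate itself.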
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | Union[AbstractUnwrappable[T] \| T] | The PyTree to unwrap. | required |
| only_if | Callable[[Any], bool] | An optional predicate function | None |
| cascade | bool | If True, unwrapping cascades to all descendants of a matched node. If False, only nodes strictly evaluating to True are unwrapped. | True |

Returns:
| Type | Description |
|---|---|
| T | A new PyTree with the appropriate |
parax.unwrap_self(method)
Decorator for Equinox module methods.
Unwraps self before executing the method so all ties/deferred
parameters are resolved.
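The decorator pattern described above might look roughly like the following plain-Python sketch. `Deferred`, `toy_unwrap`, and `toy_unwrap_self` are hypothetical stand-ins written for this example; the real decorator targets Equinox modules and is not reproduced here.

```python
import copy
import functools


class Deferred:
    """Hypothetical placeholder that computes its value on demand."""

    def __init__(self, fn):
        self.fn = fn

    def unwrap(self):
        return self.fn()


def toy_unwrap(obj):
    """Return a shallow copy of obj with Deferred attributes resolved."""
    resolved = copy.copy(obj)
    for name, value in vars(obj).items():
        if isinstance(value, Deferred):
            setattr(resolved, name, value.unwrap())
    return resolved


def toy_unwrap_self(method):
    """Decorator: resolve placeholders on self, then call the method."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        return method(toy_unwrap(self), *args, **kwargs)

    return wrapper


class Line:
    def __init__(self, slope, intercept):
        self.slope = slope
        self.intercept = intercept

    @toy_unwrap_self
    def __call__(self, x):
        # Inside the method, self.slope is already a plain number.
        return self.slope * x + self.intercept


line = Line(slope=Deferred(lambda: 2.0), intercept=1.0)
print(line(3.0))  # 7.0
```

The method body never has to call `unwrap` itself, and the original object keeps its placeholder.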
parax.wrap(template_tree, unwrapped_tree, only_if=None)
Recursively resolves AbstractWrappable nodes to reconstruct a wrapped PyTree.
This function maps across a template_tree and an unwrapped_tree simultaneously.
Wrapping is performed outside-in (top-down), perfectly inverting the inside-out
(bottom-up) process of parax.unwrap.
Note: This function is typically used to re-wrap a PyTree that was previously
unwrapped via parax.unwrap and parax.AbstractUnwrappable.
If the only_if predicate is provided, the wrapping process is conditionally gated.
The tree is searched top-down, and wrapping only triggers for subtrees that
satisfy the condition. Once a node satisfies only_if, that entire subtree
is fully wrapped.
Behavior with only_if:
1. If only_if(node) is True: The node and all of its AbstractWrappable
descendants are fully wrapped.
2. If only_if(node) is False: The node itself bypasses wrapping, but the
search continues recursively into its children.
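To make the template/unwrapped pairing and the only_if rules above more concrete, here is a toy plain-Python sketch. `Tagged` and `toy_wrap` are hypothetical names invented for this example, the trees are nested dicts, and the predicate is only called on `Tagged` nodes; the real function works on JAX PyTrees and may differ in detail.

```python
class Tagged:
    """Hypothetical wrappable node; not a parax class."""

    def __init__(self, tree, tag):
        self.tree = tree
        self.tag = tag

    def wrap(self, tree):
        return Tagged(tree, self.tag)


def toy_wrap(template, raw, only_if=None):
    """Walk the template from the root and re-wrap the matching raw values."""
    if isinstance(template, dict):
        return {k: toy_wrap(template[k], raw[k], only_if) for k in template}
    if isinstance(template, Tagged):
        if only_if is None or only_if(template):
            # Matched: wrap this node and everything below it.
            return template.wrap(toy_wrap(template.tree, raw, None))
        # Not matched: skip this wrapper but keep searching its contents.
        return toy_wrap(template.tree, raw, only_if)
    return raw


template = {"w": Tagged([1.0, 2.0], "frozen"), "b": Tagged(0.5, "bounded")}
raw = {"w": [10.0, 20.0], "b": 5.0}

everything = toy_wrap(template, raw)
only_frozen = toy_wrap(template, raw, only_if=lambda node: node.tag == "frozen")

print(everything["w"].tag, everything["w"].tree)  # frozen [10.0, 20.0]
print(only_frozen["b"])                           # 5.0
```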
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| template_tree | PyTree | The original (or blueprint) PyTree containing | required |
| unwrapped_tree | PyTree | The PyTree containing the raw, unwrapped values. | required |
| only_if | Callable[[Any], bool] | An optional predicate function | None |

Returns:
| Type | Description |
|---|---|
| PyTree | A new PyTree where the appropriate values from |
| PyTree | wrapped using the template nodes. |
parax.Freeze(tree)
Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractConstant[T]
Applies jax.lax.stop_gradient to all array-like leaves before unwrapping.
Implements the AbstractConstant interface so it can be filtered out
during optimization partitioning.
Corner Case Note: The __init__ automatically prevents double-wrapping.
If you pass a Freeze object into Freeze, it safely absorbs it rather
than nesting them.
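The double-wrapping guard mentioned in the note above could be sketched as follows. `ToyFreeze` is a hypothetical plain-Python class used only to show the absorption pattern; it does not apply `jax.lax.stop_gradient` and is not the real `parax.Freeze`.

```python
class ToyFreeze:
    """Hypothetical stand-in for the double-wrapping guard; not parax.Freeze."""

    def __init__(self, tree):
        # If handed another ToyFreeze, adopt its contents instead of nesting.
        if isinstance(tree, ToyFreeze):
            tree = tree.tree
        self.tree = tree


once = ToyFreeze([1.0, 2.0])
twice = ToyFreeze(once)
print(isinstance(twice.tree, ToyFreeze))  # False
print(twice.tree)                         # [1.0, 2.0]
```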
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | T | The PyTree to freeze. | required |
parax.Parameterize(fn, *args, **kwargs)
Bases: AbstractUnwrappable[T]
Unwrap into an arbitrary object by calling a function with arguments.
Useful for injecting dynamic generation (like neural network outputs or complex parametrizations) into an otherwise static PyTree structure.
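A minimal plain-Python sketch of the "store a call, run it on unwrap" idea described above. `ToyParameterize` and `softplus` are hypothetical names for this example; how the real class treats wrapped arguments is not covered here.

```python
import math


class ToyParameterize:
    """Hypothetical stand-in: remembers fn and its arguments until unwrap()."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        # The stored call is only executed at unwrap time.
        return self.fn(*self.args, **self.kwargs)


def softplus(x, offset=0.0):
    """Map any real number to a positive one, then shift it by offset."""
    return math.log1p(math.exp(x)) + offset


sigma = ToyParameterize(softplus, 0.0, offset=1.0)
value = sigma.unwrap()
print(value)  # approximately 1.6931
```

Storing an unconstrained number and producing a positive one at unwrap time is one example of the "complex parametrizations" use case.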
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable | The callable to execute upon unwrapping. | required |
| *args | Any | Positional arguments to pass to | () |
| **kwargs | Any | Keyword arguments to pass to | {} |
parax.Apply(fn, tree, *args, **kwargs)
Bases: AbstractUnwrappable[T]
Unwrap a PyTree by applying a function to its array-like leaves.
Corner Case Note: This relies on eqx.is_array_like. Non-array
leaves (e.g., strings, metadata) inside tree will be bypassed and
left intact, while any array-like objects will be mapped. Note that
eqx.is_array_like also treats Python scalars (float, complex, bool, int)
as array-like.
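The leaf-wise mapping described above can be pictured with a toy plain-Python version. `toy_apply` and `scale` are hypothetical names, and the `isinstance(..., float)` check is only a simplified stand-in for `eqx.is_array_like`, chosen so the example needs no third-party imports.

```python
def toy_apply(fn, tree, *args, **kwargs):
    """Apply fn to float leaves of nested dicts/lists; leave other leaves alone."""
    if isinstance(tree, dict):
        return {k: toy_apply(fn, v, *args, **kwargs) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_apply(fn, v, *args, **kwargs) for v in tree]
    if isinstance(tree, float):  # simplified stand-in for eqx.is_array_like
        return fn(tree, *args, **kwargs)
    return tree


def scale(x, factor):
    return x * factor


tree = {"weights": [1.0, 2.0], "bias": 0.5, "name": "layer0"}
mapped = toy_apply(scale, tree, factor=10.0)
print(mapped)  # {'weights': [10.0, 20.0], 'bias': 5.0, 'name': 'layer0'}
```

The string leaf passes through unchanged, while the extra keyword argument is forwarded to every call of `fn`.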
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | T | The target PyTree to map the computation over. | required |
| fn | Callable | The function to apply to each array-like leaf in the tree. | required |
| *args | Any | Positional arguments passed to | () |
| **kwargs | Any | Keyword arguments passed to | {} |
parax.Static(tree)
Bases: AbstractUnwrappable[T], AbstractWrappable[T]
Wraps a tree and marks it as static.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | T | The PyTree to mark as static. | required |
parax.Tie(tree, target, source, tie_fn=lambda x: x)
Bases: AbstractUnwrappable
A wrapper that ties subtrees/parameters together.
Upon initialization, each tied target parameter is replaced with a placeholder. Then, during unwrap, values are fetched from the source parameter, passed through tie_fn, and injected into the target.
Note that when tying an existing Tie, the paths refer to the underlying, untied model.
Attributes:
| Name | Type | Description |
|---|---|---|
| tree | Any | The underlying Equinox module or PyTree. |
| ties | tuple | A static tuple of parameter ties, formatted as |
Example

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import optax
    from parax import Tie, unwrap

    class Gaussian(eqx.Module):
        mu: jax.Array
        sigma: jax.Array

    model = Gaussian(mu=jnp.array(1.0), sigma=jnp.array(1.0))

    # Tie sigma to always be 2x mu (tie_fn can also be left out for identity)
    tied_model = Tie(
        tree=model,
        target=lambda m: m.sigma,
        source=lambda m: m.mu,
        tie_fn=lambda mu: mu * 2.0,
    )

    # Optimizers will now only see mu
    opt_state = optax.sgd(0.1).init(eqx.filter(tied_model, eqx.is_inexact_array))

    # Unwrapping resolves the tie dynamically
    active_model = unwrap(tied_model)
    print(active_model.sigma)  # Output: 2.0

Initializes the Tie wrapper and strips the target parameter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | Any | The root PyTree or Equinox module to wrap. | required |
| target | Callable[[Any], Any] | A callable (lens) that extracts the parameter to be replaced (e.g., | required |
| source | Callable[[Any], Any] | A callable (lens) that extracts the parameter to draw values from (e.g., | required |
| tie_fn | Callable[[Any], Any] | An optional transformation function applied to the source parameter before injecting it into the target. Defaults to the identity function. | lambda x: x |
parax.Bound
Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractBounded[T]
A wrapper to attach bounds to an arbitrary PyTree.
Implements parax.bounds.AbstractBounded.
Currently this wrapper does not check to ensure the leaf nodes lie within the bounds.
Attributes:
| Name | Type | Description |
|---|---|---|
| tree | T | The wrapped tree. |
| bounds | tuple[T, T] | The tree's bounds as a tuple matching its structure. |
parax.Constrain
Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractConstrainable[T]
A wrapper to attach a parax.constraints.AbstractConstraint to an arbitrary PyTree.
Assumes the original PyTree is unconstrained, i.e., its array-like leaf nodes are defined over the entire real number line.
Implements parax.constrainable.AbstractConstrainable.
Currently this wrapper does not check to ensure the leaf nodes lie within the constraints.
Attributes:
| Name | Type | Description |
|---|---|---|
| tree | T | The wrapped tree. |
| constraint | AbstractConstraint | The tree's constraint. |
parax.Probabilize(distribution, tree, constraint=None)
Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractProbabilistic[T]
A wrapper to add a probability distribution to an arbitrary PyTree.
Implements parax.probability.AbstractProbabilistic.
Attributes:
| Name | Type | Description |
|---|---|---|
| distribution | AbstractDistribution | The tree's associated probability distribution. |
| constraint | AbstractConstraint | The tree and probability distribution's constraint. If not explicitly provided during initialization, this is automatically inferred from the distribution using |
| tree | T | The wrapped tree. |

Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| distribution | AbstractDistribution | The probability distribution to associate with the tree. | required |
| tree | T | The PyTree to be wrapped. | required |
| constraint | AbstractConstraint \| None | An optional explicit constraint. If | None |
parax.as_unwrapped(tree)
Conditionally unwraps the root node if it is an AbstractUnwrappable.
Unlike unwrap, this function does not recursively traverse the PyTree.
It only evaluates the top-level object.
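The difference from a recursive unwrap can be seen in a small plain-Python sketch. `Box` and `toy_as_unwrapped` are hypothetical names for this example and only model the "root node only" behavior described above.

```python
class Box:
    """Hypothetical placeholder whose unwrap() returns its contents."""

    def __init__(self, tree):
        self.tree = tree

    def unwrap(self):
        return self.tree


def toy_as_unwrapped(tree):
    """Unwrap only the root object; never look inside it."""
    return tree.unwrap() if isinstance(tree, Box) else tree


nested = Box({"inner": Box(1)})
top = toy_as_unwrapped(nested)
print(isinstance(top["inner"], Box))  # True: the inner node is still wrapped
print(toy_as_unwrapped(3.0))          # 3.0
```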
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | Union[T, PyTree[T]] | The tree or node to potentially unwrap. | required |

Returns:
| Type | Description |
|---|---|
| T | The unwrapped result if |
| T | otherwise returns the original |
parax.as_opaque(tree)
Returns tree wrapped in either a parax.Static or parax.Freeze module, creating one if needed.
If tree is a JAX array or a structured PyTree, it is wrapped in parax.Freeze.
If tree is an unregistered Python object (an opaque leaf e.g. a lambda), it is wrapped in parax.Static
to safely bypass JAX transformations.
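The dispatch described above might be sketched like this in plain Python. `ToyFreeze`, `ToyStatic`, and `toy_as_opaque` are hypothetical stand-ins, and the type check that decides what counts as "an array or structured PyTree" is an assumption made purely for this example; the real function's classification rule is not shown in this reference.

```python
class ToyFreeze:
    """Hypothetical stand-in for parax.Freeze."""

    def __init__(self, tree):
        self.tree = tree


class ToyStatic:
    """Hypothetical stand-in for parax.Static."""

    def __init__(self, tree):
        self.tree = tree


def toy_as_opaque(tree):
    # Already opaque: hand it back untouched.
    if isinstance(tree, (ToyFreeze, ToyStatic)):
        return tree
    # Assumption for this sketch: numbers and built-in containers play the
    # role of "arrays and structured PyTrees".
    if isinstance(tree, (int, float, list, tuple, dict)):
        return ToyFreeze(tree)
    # Anything else (e.g. a lambda) is treated as an opaque leaf.
    return ToyStatic(tree)


frozen = toy_as_opaque({"w": [1.0, 2.0]})
static = toy_as_opaque(lambda x: x + 1)
same = toy_as_opaque(frozen)
print(type(frozen).__name__, type(static).__name__, same is frozen)
# ToyFreeze ToyStatic True
```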
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | Union[T \| Static[T]] | An arbitrary PyTree, array, or Python object. | required |

Returns:
| Type | Description |
|---|---|
| Union[Freeze, Static] | The tree wrapped in a Static or Freeze. If the tree is already a Static or Freeze, it is returned directly. |