Wrappers

parax.AbstractUnwrappable

Bases: Module, Generic[T]

An abstract class representing a deferred or wrapped PyTree node.

Unwrappables act as placeholders within a PyTree. When parax.unwrap is called on the tree, these nodes execute custom logic (such as delayed computation, parameter injection, or gradient stopping) and replace themselves with their output.
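The placeholder idea can be illustrated with a small, self-contained sketch. It does not use parax or JAX: `ToyUnwrappable`, `Doubled`, and `toy_unwrap` are hypothetical names, and nested dicts and lists stand in for a PyTree.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

T = TypeVar("T")


class ToyUnwrappable(ABC, Generic[T]):
    """Hypothetical stand-in for parax.AbstractUnwrappable (not an eqx.Module)."""

    @abstractmethod
    def unwrap(self) -> T:
        ...


class Doubled(ToyUnwrappable[float]):
    """Placeholder that resolves to twice its stored value."""

    def __init__(self, value: float):
        self.value = value

    def unwrap(self) -> float:
        return 2.0 * self.value


def toy_unwrap(tree):
    """Replace every placeholder found in nested dicts/lists with its output."""
    if isinstance(tree, dict):
        return {k: toy_unwrap(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [toy_unwrap(v) for v in tree]
    if isinstance(tree, ToyUnwrappable):
        return tree.unwrap()
    return tree


params = {"scale": Doubled(1.5), "bias": 0.25}
resolved = toy_unwrap(params)  # the Doubled node is replaced by a plain float
```

After the call, `resolved` holds only plain values, while `params` still contains the placeholder.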

unwrap() abstractmethod

Evaluates and returns the underlying unwrapped PyTree node.

Returns:

Type Description
T

The resolved underlying value, assuming no wrapped subnodes exist.

Source code in parax/wrappers.py
@abstractmethod
def unwrap(self) -> T:
    """Evaluates and returns the underlying unwrapped PyTree node.

    Returns:
        The resolved underlying value, assuming no wrapped subnodes exist.
    """
    pass

parax.AbstractWrappable

Bases: Module, Generic[T]

An abstract class representing a PyTree node capable of wrapping another tree.

This interface is the counterpart to parax.AbstractUnwrappable. It is typically used to define how an object should reconstruct or "re-wrap" a tree that was previously unwrapped.
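A minimal sketch of the re-wrapping contract, under the assumption that `wrap` returns a new node and leaves the template untouched. `ToyBox` is a hypothetical class written with a plain dataclass rather than an Equinox module.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToyBox:
    """Hypothetical node that both unwraps to and re-wraps its payload."""

    tree: Any
    label: str = "box"

    def unwrap(self) -> Any:
        return self.tree

    def wrap(self, tree: Any) -> "ToyBox":
        # Return a new instance; the template (self) is left untouched.
        return replace(self, tree=tree)


template = ToyBox(tree=[1, 2, 3], label="weights")
raw = template.unwrap()              # work on the raw values...
updated = [x * 10 for x in raw]
rewrapped = template.wrap(updated)   # ...then restore the wrapper around them
```

The template keeps its metadata (`label` here), so the re-wrapped node differs from the original only in its payload.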

wrap(tree) abstractmethod

Wraps the provided tree inside this node's structure.

Parameters:

Name Type Description Default
tree T

The unwrapped tree or raw value to be wrapped.

required

Returns:

Type Description
Self

A new instance of this node containing the wrapped tree.

Source code in parax/wrappers.py
@abstractmethod
def wrap(self, tree: T) -> Self:
    """Wraps the provided tree inside this node's structure.

    Args:
        tree: The unwrapped tree or raw value to be wrapped.

    Returns:
        A new instance of this node containing the wrapped tree.
    """
    pass

parax.unwrap(tree, only_if=None, cascade=True)

Recursively resolves AbstractUnwrappable nodes within a PyTree.

By default, unwrapping is performed inside-out (bottom-up) across the entire tree. Every AbstractUnwrappable node is replaced by the result of its unwrap() method.

If the only_if predicate is provided, unwrapping is conditionally gated. The cascade parameter controls how descendants of matching nodes are handled.

Behavior with only_if:

1. If cascade=True (default): The tree is searched top-down. Once a node satisfies only_if, that node and ALL of its AbstractUnwrappable descendants are fully resolved.
2. If cascade=False: The tree is traversed bottom-up. Unwrapping ONLY triggers for specific nodes that satisfy the condition. Unmatching descendants are left wrapped.
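The three modes can be compared on a toy tree. The sketch below is a simplified model of the traversal, not parax's implementation: `Tag`, `_map_children`, and `toy_unwrap` are hypothetical names, and plain dicts and lists stand in for JAX PyTrees.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Tag:
    """Hypothetical wrapper: unwraps to its payload; `name` is used by predicates."""

    name: str
    tree: Any

    def unwrap(self) -> Any:
        return self.tree


def _map_children(fn, tree):
    """Apply fn to each child of a dict/list container; other values are leaves."""
    if isinstance(tree, dict):
        return {k: fn(v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [fn(v) for v in tree]
    return tree


def toy_unwrap(tree, only_if=None, cascade=True):
    def do_unwrap(node):  # unconditional: resolve children, then the node itself
        if isinstance(node, Tag):
            return replace(node, tree=do_unwrap(node.tree)).unwrap()
        return _map_children(do_unwrap, node)

    def search(node):  # top-down: the first match triggers a full unwrap below it
        if only_if(node):
            return do_unwrap(node)
        if isinstance(node, Tag):
            return replace(node, tree=search(node.tree))
        return _map_children(search, node)

    def targeted(node):  # bottom-up: only nodes that match are unwrapped
        if isinstance(node, Tag):
            resolved = replace(node, tree=targeted(node.tree))
            return resolved.unwrap() if only_if(node) else resolved
        return _map_children(targeted, node)

    if only_if is None:
        return do_unwrap(tree)
    return search(tree) if cascade else targeted(tree)


def is_outer(node):
    return isinstance(node, Tag) and node.name == "outer"


tree = {"a": Tag("outer", [Tag("inner", 1), 2]), "b": Tag("inner", 3)}

everything = toy_unwrap(tree)
cascaded = toy_unwrap(tree, only_if=is_outer, cascade=True)
strict = toy_unwrap(tree, only_if=is_outer, cascade=False)
```

In this toy, `cascaded` resolves the inner tag nested under `"a"` along with its matching parent, whereas `strict` leaves that inner tag wrapped; in both cases the non-matching tag under `"b"` is untouched. Note that the predicate is also called on containers and plain values, so it should tolerate arbitrary inputs.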

Parameters:

Name Type Description Default
tree Union[AbstractUnwrappable[T] | T]

The PyTree to unwrap.

required
only_if Callable[[Any], bool]

An optional predicate function Callable[[Any], bool].

None
cascade bool

If True, unwrapping cascades to all descendants of a matched node. If False, only nodes strictly evaluating to True are unwrapped.

True

Returns:

Type Description
T

A new PyTree with the appropriate AbstractUnwrappable nodes resolved.

Source code in parax/wrappers.py
def unwrap(
    tree: Union[AbstractUnwrappable[T] | T],
    only_if: Callable[[Any], bool] = None,
    cascade: bool = True,
) -> T:
    """Recursively resolves `AbstractUnwrappable` nodes within a PyTree.

    By default, unwrapping is performed inside-out (bottom-up) across the entire 
    tree. Every `AbstractUnwrappable` node is replaced by the result of its 
    `unwrap()` method.

    If the `only_if` predicate is provided, unwrapping is conditionally gated. 
    The `cascade` parameter controls how descendants of matching nodes are handled.

    Behavior with `only_if`:
        1. If `cascade=True` (default): The tree is searched top-down. Once a 
           node satisfies `only_if`, that node and ALL of its `AbstractUnwrappable` 
           descendants are fully resolved.
        2. If `cascade=False`: The tree is traversed bottom-up. Unwrapping ONLY 
           triggers for specific nodes that satisfy the condition. Unmatching 
           descendants are left wrapped.

    Args:
        tree: The PyTree to unwrap.
        only_if: An optional predicate function `Callable[[Any], bool]`.
        cascade: If True, unwrapping cascades to all descendants of a matched node. 
                 If False, only nodes strictly evaluating to True are unwrapped.


    Returns:
        A new PyTree with the appropriate `AbstractUnwrappable` nodes resolved.
    """
    def _do_unwrap(node, *, include_self: bool):
        """Unconditionally unwraps the node and all unwrappable descendants."""
        def _map_fn(leaf):
            if not is_unwrappable(leaf):
                return leaf
            resolved_node = _do_unwrap(leaf, include_self=False)
            return resolved_node.unwrap()

        def is_leaf(x):
            included = True if x is not node else include_self
            return is_unwrappable(x) and included

        return jax.tree.map(_map_fn, node, is_leaf=is_leaf)

    def _search_and_unwrap(node, *, include_self: bool):
        """Top-down search: cascades an unconditional unwrap upon first match."""
        if include_self and only_if(node):
            return _do_unwrap(node, include_self=True)

        def _map_fn(leaf):
            if only_if(leaf):
                return _do_unwrap(leaf, include_self=True)
            if is_unwrappable(leaf):
                return _search_and_unwrap(leaf, include_self=False)
            return leaf

        def is_leaf(x):
            included = True if x is not node else include_self
            return (is_unwrappable(x) or only_if(x)) and included

        return jax.tree.map(_map_fn, node, is_leaf=is_leaf)

    def _targeted_unwrap(node, *, include_self: bool):
        """Bottom-up search: strictly unwraps ONLY nodes matching the condition."""
        def _map_fn(leaf):
            if not is_unwrappable(leaf):
                return leaf

            # Recursively resolve children first (bottom-up)
            resolved_node = _targeted_unwrap(leaf, include_self=False)

            # Conditionally unwrap this specific node
            if only_if(leaf):
                return resolved_node.unwrap()

            return resolved_node

        def is_leaf(x):
            included = True if x is not node else include_self
            return is_unwrappable(x) and included

        return jax.tree.map(_map_fn, node, is_leaf=is_leaf)

    # Fast path: no condition
    if only_if is None:  
        return _do_unwrap(tree, include_self=True)    

    # Condition with cascading unwrap
    if cascade:
        return _search_and_unwrap(tree, include_self=True)

    # Condition strictly applied per-node
    return _targeted_unwrap(tree, include_self=True)

parax.unwrap_self(method)

Decorator for Equinox module methods. Unwraps self before executing the method so all ties/deferred parameters are resolved.
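The decorator pattern can be sketched without Equinox. In the following toy, `Deferred`, `toy_unwrap`, `toy_unwrap_self`, and `Line` are hypothetical names; `toy_unwrap` only resolves top-level attributes, which is a simplification of what parax.unwrap does on a full PyTree.

```python
import copy
import functools


class Deferred:
    """Hypothetical placeholder that resolves to fn() when unwrapped."""

    def __init__(self, fn):
        self.fn = fn

    def unwrap(self):
        return self.fn()


def toy_unwrap(obj):
    """Shallow copy of obj with Deferred attributes resolved (toy stand-in)."""
    resolved = copy.copy(obj)
    for name, value in vars(obj).items():
        if isinstance(value, Deferred):
            setattr(resolved, name, value.unwrap())
    return resolved


def toy_unwrap_self(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # The method body receives a resolved copy instead of the original self.
        return method(toy_unwrap(self), *args, **kwargs)

    return wrapper


class Line:
    def __init__(self, slope, intercept):
        self.slope = slope
        self.intercept = intercept

    @toy_unwrap_self
    def predict(self, x):
        # Inside the method, self.slope is already a plain number.
        return self.slope * x + self.intercept


line = Line(slope=Deferred(lambda: 2.0), intercept=1.0)
y = line.predict(3.0)
```

The original object is not modified: `line.slope` remains a `Deferred` after the call, so the resolution cost is paid on every decorated call.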

Source code in parax/wrappers.py
def unwrap_self(method):
    """
    Decorator for Equinox module methods. 
    Unwraps `self` before executing the method so all ties/deferred 
    parameters are resolved.
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Explicitly use the data function on self
        unwrapped_self = unwrap(self)
        return method(unwrapped_self, *args, **kwargs)
    return wrapper

parax.wrap(template_tree, unwrapped_tree, only_if=None)

Recursively resolves AbstractWrappable nodes to reconstruct a wrapped PyTree.

This function maps across a template_tree and an unwrapped_tree simultaneously. Wrapping is performed outside-in (top-down), perfectly inverting the inside-out (bottom-up) process of parax.unwrap.

Note: This function is typically used to re-wrap a PyTree that was previously unwrapped via parax.unwrap and parax.AbstractUnwrappable.

If the only_if predicate is provided, the wrapping process is conditionally gated. The tree is searched top-down, and wrapping only triggers for subtrees that satisfy the condition. Once a node satisfies only_if, that entire subtree is fully wrapped.

Behavior with only_if:

1. If only_if(node) is True: The node and all of its AbstractWrappable descendants are fully wrapped.
2. If only_if(node) is False: The node itself bypasses wrapping, but the search continues recursively into its children.
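A toy model of the template-driven re-wrap, including the `only_if` gate. This is a sketch under simplifying assumptions, not parax's implementation: `Box`, `_zip_children`, and `toy_wrap` are hypothetical, dicts and lists stand in for PyTrees, and a template node that is not selected is assumed to still be wrapped at the same position in the value tree.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Box:
    """Hypothetical node that can unwrap its payload and re-wrap a new one."""

    name: str
    tree: Any

    def unwrap(self) -> Any:
        return self.tree

    def wrap(self, tree: Any) -> "Box":
        return replace(self, tree=tree)


def _zip_children(fn, t, u):
    """Walk template and value containers in parallel; other values are leaves."""
    if isinstance(t, dict):
        return {k: fn(t[k], u[k]) for k in t}
    if isinstance(t, list):
        return [fn(a, b) for a, b in zip(t, u)]
    return u


def toy_wrap(template, unwrapped, only_if=None):
    def do_wrap(t, u):
        if isinstance(t, Box):
            outer = t.wrap(u)  # wrap this level first (outside-in)
            return replace(outer, tree=do_wrap(t.tree, outer.tree))
        return _zip_children(do_wrap, t, u)

    def search(t, u):
        if only_if(t):
            return do_wrap(t, u)
        if isinstance(t, Box):
            # Not selected: `u` is assumed to still be wrapped at this position.
            return replace(u, tree=search(t.tree, u.tree))
        return _zip_children(search, t, u)

    if only_if is None:
        return do_wrap(template, unwrapped)
    return search(template, unwrapped)


template = {"w": Box("outer", [Box("inner", 1.0), 2.0])}
rewrapped = toy_wrap(template, {"w": [10.0, 20.0]})


def keep(node):
    return isinstance(node, Box) and node.name == "keep"


gated = toy_wrap(
    {"a": Box("keep", 1.0), "b": Box("skip", 2.0)},
    {"a": 5.0, "b": Box("skip", 6.0)},
    only_if=keep,
)
```

The first call restores both nesting levels around the new values. In the gated call only the node named `"keep"` is re-wrapped; the other entry passes through as it was supplied.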

Parameters:

Name Type Description Default
template_tree PyTree

The original (or blueprint) PyTree containing AbstractWrappable nodes.

required
unwrapped_tree PyTree

The PyTree containing the raw, unwrapped values.

required
only_if Callable[[Any], bool]

An optional predicate function Callable[[Any], bool]. If provided, only subtrees evaluating to True (and their descendants) are wrapped.

None

Returns:

Type Description
PyTree

A new PyTree where the appropriate values from unwrapped_tree have been wrapped using the template nodes.

Source code in parax/wrappers.py
def wrap(template_tree: PyTree, unwrapped_tree: PyTree, only_if: Callable[[Any], bool] = None) -> PyTree:
    """Recursively resolves `AbstractWrappable` nodes to reconstruct a wrapped PyTree.

    This function maps across a `template_tree` and an `unwrapped_tree` simultaneously. 
    Wrapping is performed outside-in (top-down), perfectly inverting the inside-out 
    (bottom-up) process of `parax.unwrap`. 

    **Note:** This function is typically used to re-wrap a PyTree that was previously 
    unwrapped via `parax.unwrap` and `parax.AbstractUnwrappable`.

    If the `only_if` predicate is provided, the wrapping process is conditionally gated.
    The tree is searched top-down, and wrapping only triggers for subtrees that 
    satisfy the condition. Once a node satisfies `only_if`, that entire subtree 
    is fully wrapped.

    Behavior with `only_if`:
        1. If `only_if(node)` is True: The node and all of its `AbstractWrappable` 
           descendants are fully wrapped.
        2. If `only_if(node)` is False: The node itself bypasses wrapping, but the 
           search continues recursively into its children.

    Args:
        template_tree: The original (or blueprint) PyTree containing `AbstractWrappable` nodes.
        unwrapped_tree: The PyTree containing the raw, unwrapped values.
        only_if: An optional predicate function `Callable[[Any], bool]`. If provided, 
            only subtrees evaluating to True (and their descendants) are wrapped.

    Returns:
        A new PyTree where the appropriate values from `unwrapped_tree` have been 
        wrapped using the template nodes.
    """
    def _do_wrap(t_node, u_node, *, include_self: bool):
        def _map_fn(t_leaf, u_leaf):
            if not is_wrappable(t_leaf):
                return u_leaf

            partially_wrapped = t_leaf.wrap(u_leaf)
            return _do_wrap(t_leaf, partially_wrapped, include_self=False)

        def is_leaf(x):
            included = True if x is not t_node else include_self
            return is_wrappable(x) and included

        return jax.tree.map(_map_fn, t_node, u_node, is_leaf=is_leaf)

    def _search_and_wrap(t_node, u_node, *, include_self: bool):
        if include_self and only_if(t_node):
            return _do_wrap(t_node, u_node, include_self=True)

        def _map_fn(t_leaf, u_leaf):
            if only_if(t_leaf):
                return _do_wrap(t_leaf, u_leaf, include_self=True)

            if is_wrappable(t_leaf):
                # Bypass wrapping this node, but keep searching its children
                return _search_and_wrap(t_leaf, u_leaf, include_self=False)

            return u_leaf

        def is_leaf(x):
            included = True if x is not t_node else include_self
            return (is_wrappable(x) or only_if(x)) and included

        return jax.tree.map(_map_fn, t_node, u_node, is_leaf=is_leaf)

    if only_if is None:
        return _do_wrap(template_tree, unwrapped_tree, include_self=True)
    return _search_and_wrap(template_tree, unwrapped_tree, include_self=True)

parax.Freeze(tree)

Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractConstant[T]

Applies jax.lax.stop_gradient to all array-like leaves before unwrapping.

Implements the AbstractConstant interface so it can be filtered out during optimization partitioning.

Corner Case Note: The __init__ automatically prevents double-wrapping. If you pass a Freeze object into Freeze, it safely absorbs it rather than nesting them.

Parameters:

Name Type Description Default
tree T

The PyTree to freeze.

required
Source code in parax/wrappers.py
def __init__(self, tree: T):
    """
    Args:
        tree: The PyTree to freeze.
    """
    if isinstance(tree, Freeze):
        tree = tree.tree
    self.tree = tree

parax.Parameterize(fn, *args, **kwargs)

Bases: AbstractUnwrappable[T]

Unwrap into an arbitrary object by calling a function with arguments.

Useful for injecting dynamic generation (like neural network outputs or complex parametrizations) into an otherwise static PyTree structure.
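The deferred-call idea can be sketched as follows. The `unwrap` body is an assumption based on the description above (the listing on this page only shows `__init__`), and `ToyParameterize` and `softplus` are hypothetical names that do not depend on parax or JAX.

```python
import math


class ToyParameterize:
    """Hypothetical stand-in: stores a call and performs it on unwrap."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = tuple(args)
        self.kwargs = kwargs

    def unwrap(self):
        # Assumed behaviour: call fn lazily with the stored arguments.
        return self.fn(*self.args, **self.kwargs)


def softplus(x, beta=1.0):
    """Smooth map from the real line to positive numbers."""
    return math.log1p(math.exp(beta * x)) / beta


# Store an unconstrained raw value; the positive scale is computed on demand.
positive_scale = ToyParameterize(softplus, 0.0, beta=2.0)
value = positive_scale.unwrap()
```

Nothing is computed at construction time, so the stored arguments can be updated (for example by an optimizer) before the value is materialized.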

Parameters:

Name Type Description Default
fn Callable

The callable to execute upon unwrapping.

required
*args Any

Positional arguments to pass to fn.

()
**kwargs Any

Keyword arguments to pass to fn.

{}
Source code in parax/wrappers.py
def __init__(self, fn: Callable, *args: Any, **kwargs: Any):
    """
    Args:
        fn: The callable to execute upon unwrapping.
        *args: Positional arguments to pass to `fn`.
        **kwargs: Keyword arguments to pass to `fn`.
    """
    self.fn = fn
    self.args = tuple(args)
    self.kwargs = kwargs

parax.Apply(fn, tree, *args, **kwargs)

Bases: AbstractUnwrappable[T]

Unwrap a PyTree by applying a function to its array-like leaves.

Corner Case Note: This relies on eqx.is_array_like. Non-array leaves (e.g., strings, standard integers, metadata) inside tree will be bypassed and left intact, while any array-like objects (including booleans) will be mapped.

Parameters:

Name Type Description Default
tree T

The target PyTree to map the computation over.

required
fn Callable

The function to apply to each array-like leaf in the tree.

required
*args Any

Positional arguments passed to fn after the leaf.

()
**kwargs Any

Keyword arguments passed to fn.

{}
Source code in parax/wrappers.py
def __init__(self, fn: Callable, tree: T, *args: Any, **kwargs: Any):
    """
    Args:
        tree: The target PyTree to map the computation over.
        fn: The function to apply to each array-like leaf in the tree.
        *args: Positional arguments passed to `fn` after the leaf.
        **kwargs: Keyword arguments passed to `fn`.
    """
    self.tree = tree
    self.fn = fn
    self.args = tuple(args)
    self.kwargs = kwargs

parax.Static(tree)

Bases: AbstractUnwrappable[T], AbstractWrappable[T]

Wraps a tree and marks it as static.

Parameters:

Name Type Description Default
tree T

The PyTree to mark as static.

required
Source code in parax/wrappers.py
def __init__(self, tree: T):
    """
    Args:
        tree: The PyTree to mark as static.
    """
    if isinstance(tree, Static):
        tree = tree.tree
    self.tree = tree

parax.Tie(tree, target, source, tie_fn=lambda x: x)

Bases: AbstractUnwrappable

A wrapper that ties subtrees/parameters together.

Upon initialization, each tied target is replaced with a placeholder. Then, during unwrap, values are fetched from the source, passed through tie_fn, and injected into the target.

Note that when tying an existing Tie, the paths refer to the underlying, untied model.

Attributes:

Name Type Description
tree Any

The underlying Equinox module or PyTree.

ties tuple

A static tuple of parameter ties, formatted as (target_extractor, source_extractor, tie_fn).

Example

class Gaussian(eqx.Module):
    mu: jax.Array
    sigma: jax.Array

model = Gaussian(mu=jnp.array(1.0), sigma=jnp.array(1.0))

# Tie sigma to always be 2x mu (tie_fn can also be left out for identity)
tied_model = Tie(
    tree=model,
    target=lambda m: m.sigma,
    source=lambda m: m.mu,
    tie_fn=lambda mu: mu * 2.0
)

# Optimizers will now only see mu
opt_state = optax.sgd(0.1).init(eqx.filter(tied_model, eqx.is_inexact_array))

# Unwrapping resolves the tie dynamically
active_model = unwrap(tied_model)
print(active_model.sigma)  # Output: 2.0

Initializes the Tie wrapper and strips the target parameter.

Parameters:

Name Type Description Default
tree Any

The root PyTree or Equinox module to wrap.

required
target Callable[[Any], Any]

A callable (lens) that extracts the parameter to be replaced (e.g., lambda m: m.layer.weight).

required
source Callable[[Any], Any]

A callable (lens) that extracts the parameter to draw values from (e.g., lambda m: m.layer.bias).

required
tie_fn Callable[[Any], Any]

An optional transformation function applied to the source parameter before injecting it into the target. Defaults to the identity function.

lambda x: x
Source code in parax/wrappers.py
def __init__(
    self, 
    tree: Any, 
    target: Callable[[Any], Any], 
    source: Callable[[Any], Any], 
    tie_fn: Callable[[Any], Any] = lambda x: x
):
    """Initializes the Tie wrapper and strips the target parameter.

    Args:
        tree: The root PyTree or Equinox module to wrap.
        target: A callable (lens) that extracts the parameter to be replaced 
            (e.g., `lambda m: m.layer.weight`).
        source: A callable (lens) that extracts the parameter to draw values 
            from (e.g., `lambda m: m.layer.bias`).
        tie_fn: An optional transformation function applied to the source 
            parameter before injecting it into the target. Defaults to the 
            identity function.
    """
    base_tree = tree.tree if isinstance(tree, Tie) else tree
    stripped_tree = eqx.tree_at(target, base_tree, replace_fn=lambda _x: _TiePlaceholder())
    new_tie = (target, source, tie_fn)
    if isinstance(tree, Tie):
        self.ties = tree.ties + (new_tie,)
    else:
        self.ties = (new_tie,)

    self.tree = stripped_tree

parax.Bound

Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractBounded[T]

A wrapper to attach bounds to an arbitrary PyTree.

Implements parax.bounds.AbstractBounded.

Currently this wrapper does not check to ensure the leaf nodes lie within the bounds.

Attributes:

Name Type Description
tree T

The wrapped tree.

bounds tuple[T, T]

The tree's bounds as a tuple matching its structure.

parax.Constrain

Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractConstrainable[T]

A wrapper to attach a parax.constraints.AbstractConstraint to an arbitrary PyTree.

Assumes the original PyTree is unconstrained, i.e., its array-like leaf nodes are defined over the entire real number line.

Implements parax.constrainable.AbstractConstrainable.

Currently this wrapper does not check to ensure the leaf nodes lie within the constraints.

Attributes:

Name Type Description
tree T

The wrapped tree.

constraint AbstractConstraint

The tree's constraint.

parax.Probabilize(distribution, tree, constraint=None)

Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractProbabilistic[T]

A wrapper to add a probability distribution to an arbitrary PyTree.

Implements parax.probability.AbstractProbabilistic.

Attributes:

Name Type Description
distribution AbstractDistribution

The tree's associated probability distribution.

constraint AbstractConstraint

The tree and probability distribution's constraint. If not explicitly provided during initialization, this is automatically inferred from the distribution using parax.constraints.infer_distribution_constraint.

tree T

The wrapped tree.

Parameters:

Name Type Description Default
distribution AbstractDistribution

The probability distribution to associate with the tree.

required
tree T

The PyTree to be wrapped.

required
constraint AbstractConstraint | None

An optional explicit constraint. If None, the constraint is inferred automatically from the distribution.

None
Source code in parax/wrappers.py
def __init__(
    self, 
    distribution: AbstractDistribution, 
    tree: T, 
    constraint: AbstractConstraint | None = None
):
    """
    Args:
        distribution: The probability distribution to associate with the tree.
        tree: The PyTree to be wrapped.
        constraint: An optional explicit constraint. If `None`, it is attempted
            to automatically infer the constraint from the `distribution`.
    """
    self.distribution = distribution
    self.tree = tree

    if constraint is None:
        from parax.constraints import infer_distribution_constraint
        self.constraint = infer_distribution_constraint(distribution)
    else:
        self.constraint = constraint

parax.as_unwrapped(tree)

Conditionally unwraps the root node if it is an AbstractUnwrappable.

Unlike unwrap, this function does not recursively traverse the PyTree. It only evaluates the top-level object.
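The difference from a recursive unwrap can be seen in a short toy. `Box` and `toy_as_unwrapped` are hypothetical names used only for illustration.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Box:
    """Hypothetical wrapper whose unwrap() returns its payload unchanged."""

    tree: Any

    def unwrap(self) -> Any:
        return self.tree


def toy_as_unwrapped(tree):
    # Only the root is inspected; nothing inside it is traversed.
    return tree.unwrap() if isinstance(tree, Box) else tree


root_only = toy_as_unwrapped(Box({"inner": Box(1.0)}))  # inner Box survives
passthrough = toy_as_unwrapped([Box(2.0)])              # root is a list: unchanged
```

Wrapped nodes below the root are left as they are, so a full resolution still needs the recursive function.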

Parameters:

Name Type Description Default
tree Union[T, PyTree[T]]

The tree or node to potentially unwrap.

required

Returns:

Type Description
T

The unwrapped result if tree was an AbstractUnwrappable; otherwise the original tree, unmodified.

Source code in parax/wrappers.py
def as_unwrapped(tree: Union[T, PyTree[T]]) -> T:
    """Conditionally unwraps the root node if it is an `AbstractUnwrappable`.

    Unlike `unwrap`, this function does not recursively traverse the PyTree. 
    It only evaluates the top-level object.

    Args:
        tree: The tree or node to potentially unwrap.

    Returns:
        The unwrapped result if `tree` was an `AbstractUnwrappable`, 
        otherwise returns the original `tree` unmodified.
    """
    if isinstance(tree, AbstractUnwrappable):
        return tree.unwrap()
    return tree

parax.as_opaque(tree)

Returns tree wrapped in either a parax.Static or parax.Freeze module, creating one if needed.

If tree is a JAX array or a structured PyTree, it is wrapped in parax.Freeze. If tree is an unregistered Python object (an opaque leaf e.g. a lambda), it is wrapped in parax.Static to safely bypass JAX transformations.
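The "opaque leaf" test can be sketched with a simplified flattener. `toy_flatten` is a hypothetical stand-in for jax.tree_util.tree_flatten that only understands dicts, lists, tuples, and None, and the sketch omits the array check that the real function applies afterwards.

```python
def toy_flatten(tree):
    """Collect leaves of nested dicts/lists/tuples; None contributes no leaves."""
    if isinstance(tree, dict):
        return [leaf for key in sorted(tree) for leaf in toy_flatten(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [leaf for item in tree for leaf in toy_flatten(item)]
    if tree is None:
        return []
    return [tree]


def is_opaque_leaf(obj):
    """True when flattening cannot unpack obj, i.e. obj is its own only leaf."""
    leaves = toy_flatten(obj)
    return len(leaves) == 1 and leaves[0] is obj


def callback(x):
    return x


callback_is_opaque = is_opaque_leaf(callback)        # a bare function
container_is_opaque = is_opaque_leaf([callback])     # a list holding one function
mapping_is_opaque = is_opaque_leaf({"a": 1.0, "b": 2.0})
```

The identity check matters: a one-element list also flattens to a single leaf, but that leaf is not the list itself, so the container is not treated as opaque.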

Parameters:

Name Type Description Default
tree Union[T | Static[T]]

An arbitrary PyTree, array, or Python object.

required

Returns:

Type Description
Union[Freeze, Static]

A Static or Freeze wrapper around the tree. If tree is already a Static or Freeze, it is returned unchanged.

Source code in parax/wrappers.py
def as_opaque(tree: Union[T | Static[T]]) -> Union[Freeze, Static]:
    """
    Returns `tree` wrapped in either a `parax.Static` or `parax.Freeze` module, creating one if needed.

    If `tree` is a JAX array or a structured PyTree, it is wrapped in `parax.Freeze`. 
    If `tree` is an unregistered Python object (an opaque leaf e.g. a lambda), it is wrapped in `parax.Static` 
    to safely bypass JAX transformations.

    Args:
        tree: An arbitrary PyTree, array, or Python object.

    Returns:
        A static or freeze version of the tree. If it is already static or freeze, returns it directly.
    """
    if isinstance(tree, Freeze | Static):
        return tree

    # Ask JAX how it views this object
    # If JAX can't unpack it, it returns a list containing exactly the object itself.
    leaves, _ = jax.tree_util.tree_flatten(tree)
    is_opaque_leaf = (len(leaves) == 1) and (leaves[0] is tree)

    if is_opaque_leaf and not eqx.is_array(tree):
        return Static(tree)
    return Freeze(tree)