
Unwrappables

parax.AbstractUnwrappable

Bases: Module, Generic[T]

An abstract class representing a deferred or wrapped PyTree node.

Unwrappables act as placeholders within a PyTree. When parax.unwrap is called on the tree, these nodes execute custom logic (like computation or gradient stopping) and replace themselves with their output.

unwrap() abstractmethod

Returns the unwrapped pytree, assuming no wrapped subnodes exist.

Source code in parax/unwrappables.py
@abstractmethod
def unwrap(self) -> T:
    """Returns the unwrapped pytree, assuming no wrapped subnodes exist."""
    pass
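As a rough illustration of this contract, the sketch below defines a stand-in base class and one concrete subclass whose `unwrap` turns an unconstrained array into a positive one. It does not use parax itself: `UnwrappableSketch` and `Softplus` are hypothetical names, and the pytree registration uses `jax.tree_util.register_pytree_node_class` rather than `eqx.Module`, which the real class builds on.

```python
import abc

import jax
import jax.numpy as jnp


class UnwrappableSketch(abc.ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable (not the real class)."""

    @abc.abstractmethod
    def unwrap(self):
        """Return the resolved value, assuming no wrapped subnodes remain."""


@jax.tree_util.register_pytree_node_class
class Softplus(UnwrappableSketch):
    """Hypothetical node: stores an unconstrained array, resolves to softplus(raw)."""

    def __init__(self, raw):
        self.raw = raw

    def unwrap(self):
        return jax.nn.softplus(self.raw)

    # Pytree plumbing: expose `raw` as the only child so JAX can trace through it.
    def tree_flatten(self):
        return (self.raw,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


node = Softplus(jnp.array([-1.0, 0.0, 2.0]))
positive = node.unwrap()  # strictly positive values, same shape as `raw`
```

Because the node is itself a pytree, transformations such as `jax.grad` see only the raw array; the constrained value exists only after unwrapping.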

parax.unwrap(tree)

Map across a PyTree and recursively resolve all AbstractUnwrappable nodes.

Corner Case Note: This function handles nested unwrappables from the inside out. If an AbstractUnwrappable contains other unwrappables, the inner nodes are recursively unwrapped before the outer node's .unwrap() method is called.

Source code in parax/unwrappables.py
def unwrap(tree: PyTree):
    """
    Map across a PyTree and recursively resolve all `AbstractUnwrappable` nodes.

    **Corner Case Note:** This function handles nested unwrappables from the 
    inside out. If an `AbstractUnwrappable` contains other unwrappables, the 
    inner nodes are recursively unwrapped *before* the outer node's `.unwrap()` 
    method is called.
    """

    def _unwrap(tree, *, include_self: bool):
        def _map_fn(leaf):
            if isinstance(leaf, AbstractUnwrappable):
                # Unwrap subnodes, then itself
                return _unwrap(leaf, include_self=False).unwrap()
            return leaf

        def is_leaf(x):
            is_unwrappable = isinstance(x, AbstractUnwrappable)
            included = include_self or x is not tree
            return is_unwrappable and included

        return jax.tree_util.tree_map(f=_map_fn, tree=tree, is_leaf=is_leaf)

    return _unwrap(tree, include_self=True)
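To make the inside-out order concrete, the sketch below copies the recursion above onto hypothetical stand-in classes and records when each node's `.unwrap()` runs. `UnwrappableSketch`, `AddOne`, and `unwrap_sketch` are illustrative names, not parax API; the stand-ins are registered with `jax.tree_util.register_pytree_node_class` instead of subclassing `eqx.Module`.

```python
import abc

import jax
import jax.numpy as jnp

ORDER = []  # records the order in which nodes are resolved


class UnwrappableSketch(abc.ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable."""

    @abc.abstractmethod
    def unwrap(self): ...


@jax.tree_util.register_pytree_node_class
class AddOne(UnwrappableSketch):
    """Hypothetical node: resolves to `child + 1` and logs its name."""

    def __init__(self, name, child):
        self.name = name
        self.child = child

    def unwrap(self):
        ORDER.append(self.name)
        return self.child + 1

    def tree_flatten(self):
        return (self.child,), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(name, *children)


def unwrap_sketch(tree):
    """Same inside-out recursion as the source above, written for the sketch class."""

    def _unwrap(tree, *, include_self):
        def _map_fn(leaf):
            if isinstance(leaf, UnwrappableSketch):
                # Resolve this node's children first, then the node itself.
                return _unwrap(leaf, include_self=False).unwrap()
            return leaf

        def is_leaf(x):
            # Unwrappables are leaves, except the root when include_self=False,
            # which forces tree_map to descend into the root's children.
            return isinstance(x, UnwrappableSketch) and (include_self or x is not tree)

        return jax.tree_util.tree_map(_map_fn, tree, is_leaf=is_leaf)

    return _unwrap(tree, include_self=True)


params = {"w": AddOne("outer", AddOne("inner", jnp.array(1.0))), "b": jnp.array(0.0)}
out = unwrap_sketch(params)
# out["w"] == 3.0 and ORDER == ["inner", "outer"]: the inner node resolves first.
```

The `include_self=False` flag is what prevents infinite recursion: when `_unwrap` is called on an unwrappable, that same object is not treated as a leaf, so `tree_map` flattens it and maps over its fields instead.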

parax.Frozen(tree)

Bases: AbstractUnwrappable[T], AbstractConstant[T]

When unwrapped, applies jax.lax.stop_gradient to all array-like leaves of the wrapped tree.

Implements the AbstractConstant interface so it can be filtered out during optimization partitioning.

Corner Case Note: __init__ automatically prevents double-wrapping. If you pass a Frozen object to Frozen, its inner tree is reused instead of nesting one Frozen inside another.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The PyTree to freeze. | required |
Source code in parax/unwrappables.py
def __init__(self, tree: T):
    """
    Args:
        tree: The PyTree to freeze.
    """
    if isinstance(tree, Frozen):
        tree = tree.tree
    self.tree = tree
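The effect of freezing on gradients can be checked with a small stand-in. The sketch below is an assumption-laden approximation, not parax code: `FrozenSketch` is a hypothetical class that copies the `__init__` shown above and, for simplicity, applies `jax.lax.stop_gradient` to every leaf (it assumes all leaves are arrays, whereas the real class targets array-like leaves only).

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class FrozenSketch:
    """Hypothetical stand-in for parax.Frozen."""

    def __init__(self, tree):
        if isinstance(tree, FrozenSketch):
            tree = tree.tree  # absorb an existing wrapper instead of nesting
        self.tree = tree

    def unwrap(self):
        # Assumes every leaf is an array; stop_gradient blocks backprop into it.
        return jax.tree_util.tree_map(jax.lax.stop_gradient, self.tree)

    def tree_flatten(self):
        return (self.tree,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(params):
    w = params["w"].unwrap()  # frozen: treated as a constant
    b = params["b"]
    return jnp.sum(w * b) + jnp.sum(b**2)


# Double-wrapping collapses to a single FrozenSketch around the array.
params = {"w": FrozenSketch(FrozenSketch(jnp.array([2.0, 3.0]))), "b": jnp.array([1.0, 1.0])}
grads = jax.grad(loss)(params)
# grads["w"].tree is all zeros; grads["b"] == w + 2 * b == [4.0, 5.0]
```

The frozen leaves still appear in the gradient pytree, but as zeros, which is why a separate filter (here, the AbstractConstant interface) is useful for excluding them from optimizer state entirely.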

parax.Parameterized(fn, *args, **kwargs)

Bases: AbstractUnwrappable[T]

Unwraps into an arbitrary object by calling fn(*args, **kwargs).

Useful for computing a value at unwrap time (such as a neural network output or a constrained parametrization) inside an otherwise static PyTree structure.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | The callable to execute upon unwrapping. | required |
| `*args` | | Positional arguments to pass to `fn`. | `()` |
| `**kwargs` | | Keyword arguments to pass to `fn`. | `{}` |
Source code in parax/unwrappables.py
def __init__(self, fn: Callable, *args, **kwargs):
    """
    Args:
        fn: The callable to execute upon unwrapping.
        *args: Positional arguments to pass to `fn`.
        **kwargs: Keyword arguments to pass to `fn`.
    """
    self.fn = fn
    self.args = tuple(args)
    self.kwargs = kwargs
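A common use is a positivity constraint: store an unconstrained value and let the callable map it into the valid range. The sketch below uses a hypothetical `ParameterizedSketch` class that mirrors the `__init__` above and assumes `unwrap` returns `fn(*args, **kwargs)`; `positive_scale` is likewise an illustrative function, not part of parax. It shows that gradients flow to the stored arguments, not to the derived value.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ParameterizedSketch:
    """Hypothetical stand-in for parax.Parameterized: unwraps to fn(*args, **kwargs)."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = tuple(args)
        self.kwargs = kwargs

    def unwrap(self):
        return self.fn(*self.args, **self.kwargs)

    # `fn` is static metadata; args and kwargs are pytree children (trainable).
    def tree_flatten(self):
        return (self.args, self.kwargs), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        args, kwargs = children
        return cls(fn, *args, **kwargs)


def positive_scale(log_scale, offset=0.0):
    """Hypothetical parametrization: exp keeps the result above `offset`."""
    return jnp.exp(log_scale) + offset


scale = ParameterizedSketch(positive_scale, jnp.array(0.0), offset=jnp.array(0.5))


def loss(p):
    return (p.unwrap() - 2.0) ** 2


grads = jax.grad(loss)(scale)  # same structure as `scale`
# scale.unwrap() == 1.5; d loss / d log_scale == d loss / d offset == -1.0
```

Keeping `fn` in the static part of the pytree means only the arguments are treated as parameters, so an optimizer updates `log_scale` while the exponential guarantees a positive result.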

parax.Computed(fn, tree, *args, **kwargs)

Bases: AbstractUnwrappable[T]

Unwraps a PyTree by applying fn(leaf, *args, **kwargs) to each of its array-like leaves.

Corner Case Note: This relies on eqx.is_array_like, which treats JAX arrays, NumPy arrays, and Python bool, int, float, and complex scalars as array-like. Those leaves are mapped by fn; other leaves inside tree (e.g., strings or arbitrary metadata objects) are bypassed and left intact.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | The function to apply to each array-like leaf in the tree. | required |
| `tree` | `T` | The target PyTree to map the computation over. | required |
| `*args` | | Positional arguments passed to `fn` after the leaf. | `()` |
| `**kwargs` | | Keyword arguments passed to `fn`. | `{}` |
Source code in parax/unwrappables.py
def __init__(self, fn: Callable, tree: T, *args, **kwargs):
    """
    Args:
        fn: The function to apply to each array-like leaf in the tree.
        tree: The target PyTree to map the computation over.
        *args: Positional arguments passed to `fn` after the leaf.
        **kwargs: Keyword arguments passed to `fn`.
    """
    self.tree = tree
    self.fn = fn
    self.args = tuple(args)
    self.kwargs = kwargs
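The leaf filtering described in the corner-case note can be illustrated with a stand-in. In the sketch below, `ComputedSketch`, `is_array_like_sketch`, and `scale` are hypothetical names; `is_array_like_sketch` only approximates `eqx.is_array_like` so the example runs with JAX and NumPy alone. The mapping assumes `unwrap` calls `fn(leaf, *args, **kwargs)` on array-like leaves, as the parameter descriptions indicate.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array_like_sketch(x):
    """Approximation of eqx.is_array_like (an assumption, not the real function)."""
    return isinstance(x, (jax.Array, np.ndarray, np.generic, bool, int, float, complex))


@jax.tree_util.register_pytree_node_class
class ComputedSketch:
    """Hypothetical stand-in for parax.Computed: maps fn over array-like leaves."""

    def __init__(self, fn, tree, *args, **kwargs):
        self.fn = fn
        self.tree = tree
        self.args = tuple(args)
        self.kwargs = kwargs

    def unwrap(self):
        def _apply(leaf):
            # Array-like leaves are transformed; anything else passes through.
            if is_array_like_sketch(leaf):
                return self.fn(leaf, *self.args, **self.kwargs)
            return leaf

        return jax.tree_util.tree_map(_apply, self.tree)

    def tree_flatten(self):
        return (self.tree, self.args, self.kwargs), self.fn

    @classmethod
    def tree_unflatten(cls, fn, children):
        tree, args, kwargs = children
        return cls(fn, tree, *args, **kwargs)


def scale(leaf, factor):
    return leaf * factor


node = ComputedSketch(scale, {"w": jnp.array([1.0, 2.0]), "n": 3, "name": "layer1"}, 10.0)
out = node.unwrap()
# out["w"] -> [10., 20.]; out["n"] -> 30.0 (a Python int is array-like); out["name"] unchanged
```

Note that the plain integer `n` is scaled along with the array: if a scalar is meant to stay fixed metadata, it has to be kept out of `tree` (or stored as a non-array type).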