
Unwrappable

parax.AbstractUnwrappable

Bases: Module, Generic[T]

An abstract class representing a deferred or wrapped PyTree node.

Unwrappables act as placeholders within a PyTree. When parax.unwrap is called on the tree, these nodes execute custom logic (such as delayed computation, parameter injection, or gradient stopping) and replace themselves with their output.
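The placeholder idea can be shown without parax itself. The sketch below is a pure-Python illustration under stated assumptions: `Unwrappable` is a hypothetical stand-in for `parax.AbstractUnwrappable` built on the standard `abc` module (the real class is an Equinox `Module`, which this sketch does not model), and `Deferred` is a hypothetical subclass that delays a computation until `unwrap()` is called.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")


class Unwrappable(ABC, Generic[T]):
    """Hypothetical stand-in for parax.AbstractUnwrappable (not the real class)."""

    @abstractmethod
    def unwrap(self) -> T:
        """Return the resolved value this placeholder stands for."""


class Deferred(Unwrappable[T]):
    """Hypothetical placeholder that delays calling fn(*args) until unwrapped."""

    def __init__(self, fn: Callable[..., T], *args: Any) -> None:
        self.fn = fn
        self.args = args

    def unwrap(self) -> T:
        # The computation runs here, not when the placeholder is constructed.
        return self.fn(*self.args)


calls: list[int] = []


def expensive(x: int) -> int:
    calls.append(x)  # record each evaluation so the delay is observable
    return x * 10


node = Deferred(expensive, 4)  # nothing has been computed yet
value = node.unwrap()          # expensive(4) runs only now
```

Because `unwrap` is abstract, the stand-in base class cannot be instantiated directly; only concrete subclasses that implement it can.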

unwrap() abstractmethod

Evaluates and returns the underlying unwrapped PyTree node.

Returns:

T: The resolved underlying value. Implementations may assume that no wrapped subnodes remain, because parax.unwrap resolves an unwrappable's descendants before calling its unwrap().

Source code in parax/unwrappable.py
@abstractmethod
def unwrap(self) -> T:
    """Evaluates and returns the underlying unwrapped PyTree node.

    Returns:
        The resolved underlying value, assuming no wrapped subnodes exist.
    """
    pass
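Since `unwrap()` may assume its children are already resolved, an implementation can work on plain values directly. In this hypothetical pure-Python sketch (again using a local `Unwrappable` stand-in rather than parax, with made-up `Const` and `Scaled` subclasses), `Scaled.unwrap()` multiplies its child by a factor and would fail if the child were still a placeholder; resolving the child first avoids that.

```python
from abc import ABC, abstractmethod
from typing import Any


class Unwrappable(ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self) -> Any: ...


class Const(Unwrappable):
    """Hypothetical placeholder that resolves to a fixed value."""

    def __init__(self, value: Any) -> None:
        self.value = value

    def unwrap(self) -> Any:
        return self.value


class Scaled(Unwrappable):
    """Hypothetical placeholder whose unwrap() assumes a resolved child."""

    def __init__(self, child: Any, factor: float) -> None:
        self.child = child
        self.factor = factor

    def unwrap(self) -> Any:
        # No check for nested placeholders: if `child` were still a Const,
        # the multiplication below would raise TypeError.
        return self.child * self.factor


outer = Scaled(Const(3.0), 2.0)

# Inside-out: resolve the child first, then the parent, which is the order
# parax.unwrap uses for nested placeholders.
resolved = Scaled(outer.child.unwrap(), outer.factor).unwrap()  # 6.0
```

Calling `outer.unwrap()` directly, before its child is resolved, raises `TypeError`; this is why the traversal order matters.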

parax.unwrap(tree, only_if=None)

Recursively resolves AbstractUnwrappable nodes within a PyTree.

By default, unwrapping is performed inside-out (bottom-up) across the entire tree: every AbstractUnwrappable node is replaced by the result of its unwrap() method, and a node's unwrappable descendants are resolved before its own unwrap() is called.
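The default inside-out pass can be sketched in pure Python. This is a simplified model, not the parax implementation: `dict`, `list` and `tuple` stand in for JAX pytree containers (so `jax.tree.map` is not used), a placeholder's children are taken to be its instance attributes, and `Unwrappable`, `Add`, `map_children` and `unwrap_all` are hypothetical names.

```python
import copy
from abc import ABC, abstractmethod
from typing import Any, Callable


class Unwrappable(ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self) -> Any: ...


class Add(Unwrappable):
    """Hypothetical placeholder that sums its (already resolved) operands."""

    def __init__(self, *operands: Any) -> None:
        self.operands = operands

    def unwrap(self) -> Any:
        return sum(self.operands)


def map_children(fn: Callable[[Any], Any], node: Any) -> Any:
    """Apply fn to the direct children of a container or placeholder.

    dict/list/tuple stand in for JAX pytree containers; a placeholder's
    children are its instance attributes.
    """
    if isinstance(node, dict):
        return {k: fn(v) for k, v in node.items()}
    if isinstance(node, (list, tuple)):
        return type(node)(fn(v) for v in node)
    if isinstance(node, Unwrappable):
        new = copy.copy(node)  # shallow copy, so the input tree is not mutated
        vars(new).update({k: fn(v) for k, v in vars(node).items()})
        return new
    return node  # anything else is a leaf


def unwrap_all(node: Any) -> Any:
    """Bottom-up: resolve all children first, then unwrap the node itself."""
    node = map_children(unwrap_all, node)
    return node.unwrap() if isinstance(node, Unwrappable) else node


tree = {"a": Add(1, Add(2, 3)), "b": [4, Add(5, 6)]}
out = unwrap_all(tree)  # {"a": 6, "b": [4, 11]}
```

The inner `Add(2, 3)` is resolved to `5` before the outer `Add` sees it, so each `unwrap()` only ever receives plain values.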

If the only_if predicate is provided, unwrapping is conditionally gated. The tree is searched top-down, and unwrapping is triggered only at nodes for which only_if returns True. Once a node satisfies only_if, its entire subtree is fully unwrapped.

Behavior with only_if:

1. If only_if(node) is True: the node and all of its AbstractUnwrappable descendants are fully resolved.
2. If only_if(node) is False: the node itself is not unwrapped (if it is an AbstractUnwrappable, it stays wrapped), but the search continues recursively into its children.
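The gated search can be modelled the same way. This pure-Python sketch is an assumption-laden simplification rather than parax's code: `dict`/`list`/`tuple` stand in for pytree containers, and `Unwrappable`, `Box`, `map_children`, `unwrap_all` and `unwrap_if` are hypothetical names. It follows the two rules above: a matching node is fully resolved, and a non-matching node is kept while its children are searched.

```python
import copy
from abc import ABC, abstractmethod
from typing import Any, Callable


class Unwrappable(ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self) -> Any: ...


class Box(Unwrappable):
    """Hypothetical placeholder that resolves to its (resolved) content."""

    def __init__(self, content: Any) -> None:
        self.content = content

    def unwrap(self) -> Any:
        return self.content


def map_children(fn: Callable[[Any], Any], node: Any) -> Any:
    """Apply fn to the direct children of a dict/list/tuple or placeholder."""
    if isinstance(node, dict):
        return {k: fn(v) for k, v in node.items()}
    if isinstance(node, (list, tuple)):
        return type(node)(fn(v) for v in node)
    if isinstance(node, Unwrappable):
        new = copy.copy(node)  # shallow copy, so the input tree is not mutated
        vars(new).update({k: fn(v) for k, v in vars(node).items()})
        return new
    return node  # anything else is a leaf


def unwrap_all(node: Any) -> Any:
    """Bottom-up resolution of every placeholder in `node`."""
    node = map_children(unwrap_all, node)
    return node.unwrap() if isinstance(node, Unwrappable) else node


def unwrap_if(node: Any, only_if: Callable[[Any], bool]) -> Any:
    """Top-down search; fully unwrap each subtree whose root matches."""
    if only_if(node):
        return unwrap_all(node)  # rule 1: resolve this node and everything below
    # Rule 2: leave this node wrapped, but keep searching its children.
    return map_children(lambda child: unwrap_if(child, only_if), node)


tree = {"frozen": Box(1), "outer": Box({"params": [Box(2), 3]})}
out = unwrap_if(tree, only_if=lambda n: isinstance(n, list))
# out["frozen"] and out["outer"] are still Box placeholders, but the list
# nested inside out["outer"] has been resolved to [2, 3].
```

Here the predicate matches only the list, so the `Box(2)` inside it is resolved even though its enclosing `Box` placeholder is left wrapped.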

Parameters:

tree (Any, required): The PyTree to unwrap.

only_if (Callable[[Any], bool], default None): An optional predicate. If provided, only subtrees whose root node satisfies the predicate (together with all of their descendants) are unwrapped. If None, every AbstractUnwrappable in the tree is unwrapped.

Returns:

Any: A new PyTree with the appropriate AbstractUnwrappable nodes resolved.

Source code in parax/unwrappable.py
def unwrap(tree: Any, only_if: Callable[[Any], bool] | None = None) -> Any:
    """Recursively resolves `AbstractUnwrappable` nodes within a PyTree.

    By default, unwrapping is performed inside-out (bottom-up) across the entire 
    tree. Every `AbstractUnwrappable` node is replaced by the result of its 
    `unwrap()` method.

    If the `only_if` predicate is provided, unwrapping is conditionally gated. 
    The tree is searched top-down, and unwrapping only triggers for subtrees 
    that satisfy the condition. Once a node satisfies `only_if`, that entire 
    subtree is fully unwrapped. 

    Behavior with `only_if`:
        1. If `only_if(node)` is True: The node and all of its `AbstractUnwrappable` 
           descendants are fully resolved.
        2. If `only_if(node)` is False: The node is left wrapped, but the search 
           continues recursively into its children.

    Args:
        tree: The PyTree to unwrap.
        only_if: An optional predicate function `Callable[[Any], bool]`. If
            provided, only subtrees whose root satisfies the predicate (and all
            of their descendants) are unwrapped.

    Returns:
        A new PyTree with the appropriate `AbstractUnwrappable` nodes resolved.
    """
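    def _do_unwrap(node, *, include_self: bool):
        def _map_fn(leaf):
            if not is_unwrappable(leaf):
                return leaf
            # Recursively unwrap the children first (bottom-up)
            resolved_node = _do_unwrap(leaf, include_self=False)
            return resolved_node.unwrap()

        def is_leaf(x):
            # Stop flattening at unwrappables so _map_fn receives them whole.
            # The root `node` only counts as a leaf when include_self is True;
            # otherwise it is flattened so that its children are visited.
            included = True if x is not node else include_self
            return is_unwrappable(x) and included

        return jax.tree.map(_map_fn, node, is_leaf=is_leaf)

    def _search_and_unwrap(node, *, include_self: bool):
        # The root itself matches: resolve its whole subtree.
        if include_self and only_if(node):
            return _do_unwrap(node, include_self=True)

        def _map_fn(leaf):
            # Matching subtree: fully unwrap it.
            if only_if(leaf):
                return _do_unwrap(leaf, include_self=True)

            # Non-matching unwrappable: keep it wrapped, but search its children.
            if is_unwrappable(leaf):
                return _search_and_unwrap(leaf, include_self=False)

            # Non-matching plain leaf: return it unchanged.
            return leaf

        def is_leaf(x):
            # Stop flattening at unwrappables and at nodes matching only_if.
            included = True if x is not node else include_self
            return (is_unwrappable(x) or only_if(x)) and included

        return jax.tree.map(_map_fn, node, is_leaf=is_leaf)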

    if only_if is None:  # fast path: no predicate, so unwrap the whole tree
        return _do_unwrap(tree, include_self=True)
    return _search_and_unwrap(tree, include_self=True)