Unwrappable
parax.AbstractUnwrappable
Bases: Module, Generic[T]
An abstract class representing a deferred or wrapped PyTree node.
Unwrappables act as placeholders within a PyTree. When parax.unwrap
is called on the tree, these nodes execute custom logic (such as delayed
computation, parameter injection, or gradient stopping) and replace
themselves with their output.
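Here is a minimal pure-Python sketch of the placeholder idea. It is not parax's implementation. The stand-in AbstractUnwrappable below uses a frozen dataclass instead of the documented Module base, so it runs without JAX. Lazy is a hypothetical subclass that models the "delayed computation" case.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")


class AbstractUnwrappable(ABC, Generic[T]):
    """Hypothetical stand-in for parax.AbstractUnwrappable (not the real class)."""

    @abstractmethod
    def unwrap(self) -> T:
        """Return the resolved value this placeholder stands for."""


@dataclass(frozen=True)
class Lazy(AbstractUnwrappable[Any]):
    """Hypothetical wrapper that defers calling fn(*args) until unwrap()."""

    fn: Callable[..., Any]
    args: tuple = ()

    def unwrap(self) -> Any:
        # The computation runs only here, not when the node is constructed.
        return self.fn(*self.args)


calls = []


def expensive(x):
    calls.append(x)  # record that the computation actually ran
    return x * 10


node = Lazy(expensive, (4,))
assert calls == []     # building the placeholder computed nothing
value = node.unwrap()  # resolving it runs the deferred computation
```

Because unwrap is abstract, the base class cannot be instantiated directly. Only concrete subclasses that supply their own resolution logic can be placed in a tree.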
unwrap() (abstractmethod)
Evaluates and returns the underlying unwrapped PyTree node.
Returns:
| Type | Description |
|---|---|
| T | The resolved underlying value, assuming no wrapped subnodes exist. |
Source code in parax/unwrappable.py
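Under the "assuming no wrapped subnodes exist" clause, an unwrap() implementation may treat its fields as plain values. The sketch below shows what goes wrong when a child is still wrapped and why that child must be resolved first. It uses a hypothetical stand-in base class, and Const and Scaled are illustrative names, not parax classes.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class AbstractUnwrappable(ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self):
        ...


@dataclass(frozen=True)
class Const(AbstractUnwrappable):
    value: float

    def unwrap(self):
        return self.value


@dataclass(frozen=True)
class Scaled(AbstractUnwrappable):
    value: object
    scale: float

    def unwrap(self):
        # Relies on self.value already being a plain number, not a wrapper.
        return self.value * self.scale


print(Scaled(2.0, 3.0).unwrap())  # 6.0

nested = Scaled(Const(2.0), 3.0)
try:
    nested.unwrap()  # Const(2.0) * 3.0 is not defined, so this raises
except TypeError as err:
    print("unwrap() on a node with a wrapped child fails:", err)

# Resolving the child first (bottom-up) satisfies the contract.
fixed = Scaled(nested.value.unwrap(), nested.scale)
print(fixed.unwrap())  # 6.0
```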
parax.unwrap(tree, only_if=None)
Recursively resolves AbstractUnwrappable nodes within a PyTree.
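By default, unwrapping is performed inside-out (bottom-up) across the entire
tree. Every AbstractUnwrappable node is replaced by the result of its
unwrap() method. Because the traversal is bottom-up, a node's descendants
are always resolved before its own unwrap() is called.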
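The default bottom-up pass can be modeled as a post-order traversal. The sketch below is a hypothetical pure-Python version, not parax's source. unwrap_all, map_children, Const and Add are illustrative names, and the sketch only recurses into lists, tuples, dicts and dataclass-based unwrappables, whereas parax.unwrap operates on general PyTrees.

```python
import dataclasses
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable


class AbstractUnwrappable(ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self) -> Any:
        ...


def map_children(node: Any, fn: Callable[[Any], Any]) -> Any:
    """Apply fn to each direct child of a container node; leaves are returned as-is."""
    if isinstance(node, list):
        return [fn(child) for child in node]
    if isinstance(node, tuple):
        return tuple(fn(child) for child in node)
    if isinstance(node, dict):
        return {key: fn(child) for key, child in node.items()}
    if isinstance(node, AbstractUnwrappable) and dataclasses.is_dataclass(node):
        # Rebuild the wrapper with transformed fields; the wrapper itself is kept.
        updates = {f.name: fn(getattr(node, f.name)) for f in dataclasses.fields(node)}
        return dataclasses.replace(node, **updates)
    return node


def unwrap_all(tree: Any) -> Any:
    """Post-order: resolve children first, then the node itself."""
    tree = map_children(tree, unwrap_all)
    if isinstance(tree, AbstractUnwrappable):
        return tree.unwrap()
    return tree


@dataclass(frozen=True)
class Const(AbstractUnwrappable):
    value: Any

    def unwrap(self) -> Any:
        return self.value


@dataclass(frozen=True)
class Add(AbstractUnwrappable):
    a: Any
    b: Any

    def unwrap(self) -> Any:
        # Safe because unwrap_all has already resolved a and b.
        return self.a + self.b


tree = {"x": Add(Const(1), Const(2)), "y": [Const(3), 4]}
print(unwrap_all(tree))  # {'x': 3, 'y': [3, 4]}
```

The input tree is left untouched. Each container and wrapper is rebuilt, which mirrors the "new PyTree" that parax.unwrap returns.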
If the only_if predicate is provided, unwrapping is conditionally gated.
The tree is searched top-down, and unwrapping only triggers for subtrees
that satisfy the condition. Once a node satisfies only_if, that entire
subtree is fully unwrapped.
Behavior with only_if:
1. If only_if(node) is True: The node and all of its AbstractUnwrappable
descendants are fully resolved.
2. If only_if(node) is False: The node itself is not unwrapped (an
AbstractUnwrappable node stays wrapped), but the search continues
recursively into its children.
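Together, the two rules describe a top-down search that switches to a full bottom-up unwrap the first time the predicate holds. The following pure-Python sketch models the documented semantics and is not parax's source. unwrap_sketch, map_children, unwrap_all, Const and Add are hypothetical names. Only lists, tuples, dicts and dataclass-based unwrappables are traversed.

```python
import dataclasses
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Optional


class AbstractUnwrappable(ABC):
    """Hypothetical stand-in for parax.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self) -> Any:
        ...


def map_children(node: Any, fn: Callable[[Any], Any]) -> Any:
    """Apply fn to each direct child of a container node; leaves are returned as-is."""
    if isinstance(node, list):
        return [fn(child) for child in node]
    if isinstance(node, tuple):
        return tuple(fn(child) for child in node)
    if isinstance(node, dict):
        return {key: fn(child) for key, child in node.items()}
    if isinstance(node, AbstractUnwrappable) and dataclasses.is_dataclass(node):
        updates = {f.name: fn(getattr(node, f.name)) for f in dataclasses.fields(node)}
        return dataclasses.replace(node, **updates)
    return node


def unwrap_all(tree: Any) -> Any:
    """Bottom-up: resolve children first, then the node itself."""
    tree = map_children(tree, unwrap_all)
    return tree.unwrap() if isinstance(tree, AbstractUnwrappable) else tree


def unwrap_sketch(tree: Any, only_if: Optional[Callable[[Any], bool]] = None) -> Any:
    """Hypothetical model of unwrap(tree, only_if)."""
    if only_if is None:
        return unwrap_all(tree)  # default: unwrap everything
    if only_if(tree):
        return unwrap_all(tree)  # rule 1: resolve this whole subtree
    # Rule 2: keep this node, but keep searching its children top-down.
    return map_children(tree, lambda child: unwrap_sketch(child, only_if))


@dataclass(frozen=True)
class Const(AbstractUnwrappable):
    value: Any

    def unwrap(self) -> Any:
        return self.value


@dataclass(frozen=True)
class Add(AbstractUnwrappable):
    a: Any
    b: Any

    def unwrap(self) -> Any:
        return self.a + self.b


def only_adds(node: Any) -> bool:
    return isinstance(node, Add)


tree = {"x": Add(Const(1), Const(2)), "y": Const(Add(Const(1), Const(1))), "z": Const(5)}
print(unwrap_sketch(tree, only_if=only_adds))
# {'x': 3, 'y': Const(value=2), 'z': Const(value=5)}
```

Entry "y" keeps its outer Const wrapper because the predicate rejected it, yet the Add nested inside is still found and resolved. Entry "x" is resolved completely, including its Const children, because the predicate accepted its root.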
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | Any | The PyTree to unwrap. | required |
| only_if | Callable[[Any], bool] | An optional predicate that gates unwrapping during the top-down search. If None, the entire tree is unwrapped. | None |
Returns:
| Type | Description |
|---|---|
| Any | A new PyTree in which the appropriate AbstractUnwrappable nodes have been resolved. |
Source code in parax/unwrappable.py