Wrappable
parax.AbstractWrappable
Bases: Module, Generic[T]
An abstract class representing a PyTree node capable of wrapping another tree.
This interface is the counterpart to parax.AbstractUnwrappable. It is
typically used to define how an object should reconstruct or "re-wrap" a
tree that was previously unwrapped.
wrap(tree)
abstractmethod
Wraps the provided tree inside this node's structure.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | T | The unwrapped tree or raw value to be wrapped. | required |
Returns:

| Type | Description |
|---|---|
| Self | A new instance of this node containing the wrapped tree. |
Source code in parax/wrappable.py (lines 24-34)
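To show what implementing the abstract wrap method involves, here is a minimal sketch under stated assumptions. ToyWrappable and Scaled are hypothetical names, not part of parax. The real base class also derives from Module; this sketch uses a plain ABC plus a frozen dataclass so it runs without extra dependencies. The key idea is that wrap keeps the node's own metadata and returns a new node holding the given tree, leaving the original untouched.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Generic, TypeVar

T = TypeVar("T")


class ToyWrappable(ABC, Generic[T]):
    """Hypothetical stand-in for parax.AbstractWrappable (plain Python, not a Module)."""

    @abstractmethod
    def wrap(self, tree: T) -> "ToyWrappable[T]":
        """Return a new node of this type holding `tree`."""


@dataclass(frozen=True)
class Scaled(ToyWrappable[float]):
    """Hypothetical wrapper: a raw payload (`value`) plus fixed metadata (`scale`)."""

    value: float
    scale: float = 1.0

    def wrap(self, tree: float) -> "Scaled":
        # Keep this node's metadata and swap in the new payload.
        # `replace` builds a new instance, so `self` is never mutated.
        return replace(self, value=tree)

    @property
    def scaled(self) -> float:
        return self.value * self.scale


template = Scaled(value=0.0, scale=2.0)
new = template.wrap(3.0)
print(new, new.scaled)
```

Returning a fresh instance rather than mutating matches the documented contract ("a new instance of this node") and fits immutable PyTree nodes.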
parax.wrap(template_tree, unwrapped_tree, only_if=None)
Recursively resolves AbstractWrappable nodes to reconstruct a wrapped PyTree.
This function traverses template_tree and unwrapped_tree in parallel.
Wrapping proceeds outside-in (top-down), the reverse of the inside-out
(bottom-up) order used by parax.unwrap.
Note: This function is typically used to re-wrap a PyTree that was previously
unwrapped via parax.unwrap and parax.AbstractUnwrappable.
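The round trip described above can be illustrated with a toy model. This is a minimal sketch, not parax's implementation: Box, toy_unwrap, and toy_wrap are hypothetical, the only containers handled are dicts, lists, and tuples, and a wrapper is assumed to hold its payload in a single inner field. toy_unwrap strips wrappers bottom-up. toy_wrap descends from the root of the template alongside the raw values, and each template Box marks where a wrapper is restored. It does not try to reproduce parax's exact visiting order.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Box:
    """Hypothetical wrapper node: `label` is metadata, `inner` is the wrapped subtree."""

    label: str
    inner: Any

    def wrap(self, tree: Any) -> "Box":
        return replace(self, inner=tree)


def toy_unwrap(tree: Any) -> Any:
    """Bottom-up: unwrap the payload first, then drop the Box itself."""
    if isinstance(tree, Box):
        return toy_unwrap(tree.inner)
    if isinstance(tree, dict):
        return {k: toy_unwrap(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(toy_unwrap(v) for v in tree)
    return tree


def toy_wrap(template: Any, unwrapped: Any) -> Any:
    """Walk the template and the raw values together, starting at the root."""
    if isinstance(template, Box):
        # The template says a wrapper belongs here. Its own payload (which may
        # contain nested Boxes) is matched against the same raw value.
        return template.wrap(toy_wrap(template.inner, unwrapped))
    if isinstance(template, dict):
        return {k: toy_wrap(v, unwrapped[k]) for k, v in template.items()}
    if isinstance(template, (list, tuple)):
        return type(template)(toy_wrap(t, u) for t, u in zip(template, unwrapped))
    return unwrapped  # plain leaf: take the raw value


template = {"a": Box("pos", 1.0), "b": [Box("outer", Box("inner", 2.0)), 3.0]}
raw = toy_unwrap(template)
print(raw)
print(toy_wrap(template, {"a": 10.0, "b": [20.0, 30.0]}))
```

Wrapping the unwrapped tree against its own template returns the original structure, which is the inverse relationship the paragraph above describes.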
If an only_if predicate is provided, wrapping is gated by it. The tree is
searched top-down, and wrapping is triggered only for subtrees whose root
satisfies the predicate. Once a node satisfies only_if, the entire subtree
rooted at that node is fully wrapped.
Behavior with only_if:
1. If only_if(node) is True: The node and all of its AbstractWrappable
descendants are fully wrapped.
2. If only_if(node) is False: The node itself bypasses wrapping, but the
search continues recursively into its children.
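The two only_if rules above can be sketched with the same kind of toy model. Everything here is hypothetical (Box, toy_wrap, and the label-based predicates), and the sketch assumes only_if is evaluated on template nodes; it is meant to illustrate the gating behavior, not to mirror parax's code. A matching node switches to unconditional wrapping for its whole subtree, while a non-matching Box is skipped but its payload is still searched.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class Box:
    """Hypothetical wrapper node (stand-in for an AbstractWrappable)."""

    label: str
    inner: Any

    def wrap(self, tree: Any) -> "Box":
        return replace(self, inner=tree)


def toy_wrap(template: Any, unwrapped: Any,
             only_if: Optional[Callable[[Any], bool]] = None) -> Any:
    # Rule 1: a matching node switches to unconditional wrapping for its subtree.
    if only_if is not None and only_if(template):
        return toy_wrap(template, unwrapped, None)
    wrap_here = only_if is None  # no predicate active: wrap everything
    if isinstance(template, Box):
        inner = toy_wrap(template.inner, unwrapped, only_if)
        # Rule 2: a non-matching Box is not re-wrapped, but its payload was searched.
        return template.wrap(inner) if wrap_here else inner
    if isinstance(template, dict):
        return {k: toy_wrap(v, unwrapped[k], only_if) for k, v in template.items()}
    if isinstance(template, (list, tuple)):
        return type(template)(
            toy_wrap(t, u, only_if) for t, u in zip(template, unwrapped)
        )
    return unwrapped


template = {"a": Box("pos", 1.0), "b": [Box("outer", Box("inner", 2.0)), 3.0]}
raw = {"a": 10.0, "b": [20.0, 30.0]}


def is_inner(node: Any) -> bool:
    return isinstance(node, Box) and node.label == "inner"


def is_outer(node: Any) -> bool:
    return isinstance(node, Box) and node.label == "outer"


print(toy_wrap(template, raw, only_if=is_inner))
print(toy_wrap(template, raw, only_if=is_outer))
```

With is_inner, only the innermost Box is restored and the outer one is bypassed. With is_outer, the match at the outer Box wraps its whole subtree, so the nested inner Box comes back too even though it does not satisfy the predicate.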
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| template_tree | PyTree | The original (or blueprint) PyTree containing the AbstractWrappable nodes. | required |
| unwrapped_tree | PyTree | The PyTree containing the raw, unwrapped values. | required |
| only_if | Callable[[Any], bool] | An optional predicate function that gates wrapping, as described under "Behavior with only_if" above. | None |
Returns:

| Type | Description |
|---|---|
| PyTree | A new PyTree in which the appropriate values from unwrapped_tree are wrapped using the template nodes. |
Source code in parax/wrappable.py (lines 49-116)