unwrap (pmrf.unwrap)
- pmrf.unwrap(tree: AbstractUnwrappable[T] | T, only_if: Callable[[Any], bool] = None) -> T
Recursively resolves AbstractUnwrappable nodes within a PyTree.
By default, unwrapping is performed inside-out (bottom-up) across the entire tree. Every AbstractUnwrappable node is replaced by the result of its unwrap() method.
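The bottom-up order can be illustrated with a minimal pure-Python sketch. It does not use pmrf. AbstractUnwrappable here is a hypothetical stand-in, Add1, Double and unwrap_all are hypothetical names, and plain lists, tuples, dicts and dataclass fields play the role of a PyTree. Because children are resolved before their parent, Double's unwrap() receives an already-resolved number rather than an Add1 wrapper.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any


class AbstractUnwrappable(ABC):
    """Hypothetical stand-in for pmrf.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self) -> Any: ...


@dataclass(frozen=True)
class Add1(AbstractUnwrappable):
    """Hypothetical wrapper that resolves to value + 1."""
    value: Any

    def unwrap(self) -> Any:
        return self.value + 1


@dataclass(frozen=True)
class Double(AbstractUnwrappable):
    """Hypothetical wrapper that resolves to value * 2."""
    value: Any

    def unwrap(self) -> Any:
        # Would raise TypeError if self.value were still an unresolved Add1.
        return self.value * 2


def unwrap_all(tree: Any) -> Any:
    """Bottom-up resolution over lists, tuples, dicts and unwrappables."""
    if isinstance(tree, list):
        return [unwrap_all(x) for x in tree]
    if isinstance(tree, tuple):
        return tuple(unwrap_all(x) for x in tree)
    if isinstance(tree, dict):
        return {k: unwrap_all(v) for k, v in tree.items()}
    if isinstance(tree, AbstractUnwrappable):
        # Resolve the node's fields first, then call its own unwrap().
        children = {f.name: unwrap_all(getattr(tree, f.name)) for f in fields(tree)}
        return replace(tree, **children).unwrap()
    return tree  # plain leaf: returned unchanged


tree = {"a": Double(Add1(3)), "b": [Add1(0), 5]}
print(unwrap_all(tree))  # {'a': 8, 'b': [1, 5]}
```

The input tree is not modified; a new structure is built, consistent with the return value described below.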
If the only_if predicate is provided, unwrapping is conditionally gated. The tree is searched top-down, and unwrapping only triggers for subtrees that satisfy the condition. Once a node satisfies only_if, that entire subtree is fully unwrapped.
- Behavior with only_if:
If only_if(node) is True: The node and all of its AbstractUnwrappable descendants are fully resolved.
If only_if(node) is False: The node itself is not unwrapped, but the search continues recursively into its children.
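The only_if rules above can be sketched the same way, again with hypothetical stand-ins rather than the real pmrf classes. unwrap_sketch and _map_children are hypothetical names. The sketch assumes the predicate is tested on every node, including plain leaves, which may differ from pmrf's actual traversal.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any, Callable, Optional


class AbstractUnwrappable(ABC):
    """Hypothetical stand-in for pmrf.AbstractUnwrappable."""

    @abstractmethod
    def unwrap(self) -> Any: ...


@dataclass(frozen=True)
class Add1(AbstractUnwrappable):
    value: Any

    def unwrap(self) -> Any:
        return self.value + 1


@dataclass(frozen=True)
class Double(AbstractUnwrappable):
    value: Any

    def unwrap(self) -> Any:
        return self.value * 2


def _map_children(node: Any, fn: Callable[[Any], Any]) -> Any:
    """Apply fn to each direct child of a list, tuple, dict or unwrappable."""
    if isinstance(node, list):
        return [fn(x) for x in node]
    if isinstance(node, tuple):
        return tuple(fn(x) for x in node)
    if isinstance(node, dict):
        return {k: fn(v) for k, v in node.items()}
    if isinstance(node, AbstractUnwrappable):
        return replace(node, **{f.name: fn(getattr(node, f.name)) for f in fields(node)})
    return node  # leaf: no children


def unwrap_sketch(tree: Any, only_if: Optional[Callable[[Any], bool]] = None) -> Any:
    if only_if is None or only_if(tree):
        # Full bottom-up resolution of this whole subtree (no more gating).
        resolved = _map_children(tree, unwrap_sketch)
        if isinstance(resolved, AbstractUnwrappable):
            return resolved.unwrap()
        return resolved
    # Predicate is False: keep this node wrapped, keep searching its children.
    return _map_children(tree, lambda child: unwrap_sketch(child, only_if))


print(unwrap_sketch(Double(Add1(3)), only_if=lambda n: isinstance(n, Add1)))
print(unwrap_sketch([Double(Add1(3)), Add1(1)], only_if=lambda n: isinstance(n, Double)))
```

With only_if=lambda n: isinstance(n, Add1), the outer Double fails the predicate and stays wrapped while its Add1 child resolves to 4, giving Double(value=4). With a predicate matching Double, that whole subtree resolves to 8, while the sibling Add1(1) is left wrapped.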
- Args:
tree: The PyTree to unwrap.
only_if: An optional predicate function of type Callable[[Any], bool]. If provided, only subtrees for which it returns True (and their descendants) are unwrapped.
- Returns:
A new PyTree with the appropriate AbstractUnwrappable nodes resolved.