Unwrappables
parax.AbstractUnwrappable
Bases: Module, Generic[T]
An abstract class representing a deferred or wrapped PyTree node.
Unwrappables act as placeholders within a PyTree. When parax.unwrap
is called on the tree, these nodes execute custom logic (like computation
or gradient stopping) and replace themselves with their output.
unwrap()
abstractmethod
Returns the unwrapped pytree, assuming no wrapped subnodes exist.
Source code in parax/unwrappables.py
parax.unwrap(tree)
Map across a PyTree and recursively resolve all AbstractUnwrappable nodes.
Corner Case Note: This function handles nested unwrappables from the
inside out. If an AbstractUnwrappable contains other unwrappables, the
inner nodes are recursively unwrapped before the outer node's .unwrap()
method is called.
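The inside-out ordering described above can be illustrated with a small standalone sketch. It does not import parax or JAX: `UnwrappableSketch`, `Scaled`, and `unwrap_sketch` are hypothetical stand-ins written for this example, and the traversal only handles dicts, lists, and tuples. It is not the library's implementation, only one way the described behavior could look.

```python
from abc import ABC, abstractmethod


class UnwrappableSketch(ABC):
    """Hypothetical stand-in for an unwrappable node (not parax's class)."""

    @abstractmethod
    def unwrap(self):
        """Return the replacement value, assuming children are already resolved."""


class Scaled(UnwrappableSketch):
    """Toy node that unwraps to `factor * value`."""

    def __init__(self, value, factor):
        self.value = value
        self.factor = factor

    def unwrap(self):
        return self.factor * self.value


def unwrap_sketch(tree):
    """Resolve nodes from the inside out over dicts, lists, and tuples."""
    if isinstance(tree, UnwrappableSketch):
        # First resolve anything nested inside the node's own fields ...
        resolved = object.__new__(type(tree))
        resolved.__dict__.update(
            {name: unwrap_sketch(value) for name, value in vars(tree).items()}
        )
        # ... and only then call the outer node's unwrap().
        return resolved.unwrap()
    if isinstance(tree, dict):
        return {key: unwrap_sketch(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(unwrap_sketch(value) for value in tree)
    return tree  # ordinary leaf: left untouched


# The inner Scaled(2.0, 3.0) is resolved before the outer node multiplies by 10.
nested = {"w": Scaled(Scaled(2.0, 3.0), 10.0), "name": "layer"}
result = unwrap_sketch(nested)
```

Because the inner node is replaced by its output first, the outer node's `unwrap()` only ever sees plain values, which matches the assumption stated for `unwrap()` that no wrapped subnodes exist.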
Source code in parax/unwrappables.py
parax.Frozen(tree)
Bases: AbstractUnwrappable[T], AbstractConstant[T]
Applies jax.lax.stop_gradient to all array-like leaves before unwrapping.
Implements the AbstractConstant interface so it can be filtered out
during optimization partitioning.
Corner Case Note: The __init__ automatically prevents double-wrapping.
If you pass a Frozen object to Frozen, the constructor absorbs the
already-frozen object rather than nesting the two wrappers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The PyTree to freeze. | required |
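The double-wrapping guard mentioned in the corner case note can be sketched without the library. `FrozenSketch` and its `tree` attribute are hypothetical names chosen for this illustration, and the gradient-stopping step is deliberately left out; the sketch only shows how a constructor might absorb an already-wrapped argument.

```python
class FrozenSketch:
    """Hypothetical stand-in illustrating the double-wrapping guard."""

    def __init__(self, tree):
        # If the argument is already wrapped, take over its payload
        # instead of storing a wrapper inside a wrapper.
        if isinstance(tree, FrozenSketch):
            tree = tree.tree
        self.tree = tree


params = {"bias": [0.0, 1.0]}
once = FrozenSketch(params)
twice = FrozenSketch(once)  # wrapping again does not add a second layer
```

With a guard like this, wrapping is idempotent: code that freezes a subtree does not need to check whether a caller has frozen it already.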
Source code in parax/unwrappables.py
parax.Parameterized(fn, *args, **kwargs)
Bases: AbstractUnwrappable[T]
Unwrap into an arbitrary object by calling a function with arguments.
Useful for injecting dynamic generation (like neural network outputs or complex parametrizations) into an otherwise static PyTree structure.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | The callable to execute upon unwrapping. | required |
| `*args` | | Positional arguments to pass to `fn`. | `()` |
| `**kwargs` | | Keyword arguments to pass to `fn`. | `{}` |
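A minimal sketch of the deferred-call idea follows, assuming the node simply stores the callable and its arguments and invokes them on unwrap. `ParameterizedSketch`, `scaled_sum`, and the `calls` log are hypothetical names for this example, not part of parax.

```python
class ParameterizedSketch:
    """Hypothetical stand-in: stores a call and runs it only when unwrapped."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        # The stored callable runs here, not at construction time.
        return self.fn(*self.args, **self.kwargs)


calls = []  # records each invocation so the deferral is observable


def scaled_sum(a, b, scale=1.0):
    calls.append((a, b, scale))
    return scale * (a + b)


node = ParameterizedSketch(scaled_sum, 1.0, 2.0, scale=0.5)
calls_before_unwrap = len(calls)  # nothing has been computed yet
value = node.unwrap()             # the call happens now
```

Deferring the call means the stored arguments can be updated (for example by an optimizer) between construction and unwrapping, and the output always reflects their current values.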
Source code in parax/unwrappables.py
parax.Computed(fn, tree, *args, **kwargs)
Bases: AbstractUnwrappable[T]
Unwrap a PyTree by applying a function to its array-like leaves.
Corner Case Note: This relies on eqx.is_array_like. Non-array
leaves (e.g., strings and other metadata objects) inside tree
are bypassed and left intact, while every array-like leaf is mapped.
Note that eqx.is_array_like also returns True for Python bool, int,
float, and complex scalars, so those leaves are mapped as well.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The target PyTree to map the computation over. | required |
| `fn` | `Callable` | The function to apply to each array-like leaf in the tree. | required |
| `*args` | | Positional arguments passed to `fn`. | `()` |
| `**kwargs` | | Keyword arguments passed to `fn`. | `{}` |
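The leaf-filtering behavior can be sketched in plain Python. Everything here is an assumption made for illustration: `is_array_like_sketch` is a stand-in predicate that only recognizes Python scalars (the real check is `eqx.is_array_like`, which also accepts NumPy and JAX arrays), `map_array_like` is a hypothetical helper, and passing the extra arguments after the leaf, as `fn(leaf, *args, **kwargs)`, is a guess at the calling convention rather than documented behavior.

```python
def is_array_like_sketch(leaf):
    """Stand-in predicate: treats Python numeric scalars and bools as array-like."""
    return isinstance(leaf, (bool, int, float, complex))


def map_array_like(fn, tree, *args, **kwargs):
    """Apply fn to array-like leaves; leave every other leaf untouched."""
    if isinstance(tree, dict):
        return {k: map_array_like(fn, v, *args, **kwargs) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_array_like(fn, v, *args, **kwargs) for v in tree)
    if is_array_like_sketch(tree):
        # Assumed convention: the leaf comes first, extra arguments follow.
        return fn(tree, *args, **kwargs)
    return tree  # e.g. strings pass through unchanged


def affine(leaf, factor, shift=0.0):
    return leaf * factor + shift


tree = {"weights": [1.0, 2.0], "label": "encoder", "flag": True}
mapped = map_array_like(affine, tree, 2.0, shift=1.0)
```

Note that the boolean `flag` is transformed along with the floats, while the string `label` is left alone. If a boolean or other scalar setting should stay fixed, it needs to be kept outside the mapped tree or handled inside `fn`.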
Source code in parax/unwrappables.py