Wrappers
parax.Frozen(tree)
Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractConstant[T]
Applies jax.lax.stop_gradient to all array-like leaves before unwrapping.
Implements the AbstractConstant interface so it can be filtered out
during optimization partitioning.
Corner Case Note: `__init__` automatically prevents double-wrapping.
Passing a Frozen object to Frozen absorbs the inner wrapper instead
of nesting the two.
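The behaviour described above can be approximated in plain JAX. The following is a minimal sketch, not the code in parax/wrappers.py: `FrozenSketch`, `freeze`, `is_array`, and the `unwrap` method name are hypothetical stand-ins introduced for illustration, and the array check is deliberately simpler than a full array-like test.

```python
import jax
import jax.numpy as jnp


def is_array(leaf):
    # Simplified stand-in for an array-like check (assumption).
    return isinstance(leaf, jax.Array)


def freeze(tree):
    """Apply stop_gradient to every array leaf; pass other leaves through."""
    return jax.tree_util.tree_map(
        lambda leaf: jax.lax.stop_gradient(leaf) if is_array(leaf) else leaf,
        tree,
    )


class FrozenSketch:
    """Hypothetical wrapper illustrating the freeze-on-unwrap idea."""

    def __init__(self, tree):
        # Absorb an already-wrapped input instead of nesting wrappers.
        if isinstance(tree, FrozenSketch):
            tree = tree.tree
        self.tree = tree

    def unwrap(self):
        return freeze(self.tree)


def loss(params):
    # The bias contributes to the loss value but receives no gradient.
    frozen_bias = FrozenSketch(params["bias"]).unwrap()
    return jnp.sum(params["weight"] ** 2) + jnp.sum(frozen_bias ** 2)


params = {"weight": jnp.array([1.0, 2.0]), "bias": jnp.array([3.0, 4.0])}
grads = jax.grad(loss)(params)

# Wrapping twice keeps a single layer around the original tree.
nested = FrozenSketch(FrozenSketch(params))
```

In this sketch the gradient with respect to `bias` is zero because `stop_gradient` blocks differentiation through the frozen leaves, while `weight` is differentiated as usual.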
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The PyTree to freeze. | required |
Source code in parax/wrappers.py
parax.Parameterized(fn, *args, **kwargs)
Bases: AbstractUnwrappable[T]
Unwrap into an arbitrary object by calling a function with arguments.
Useful for injecting dynamic generation (like neural network outputs or complex parametrizations) into an otherwise static PyTree structure.
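As an illustration of the deferred-call idea, here is a minimal sketch in plain Python with JAX arrays. It is not the parax implementation: `ParameterizedSketch` and its `unwrap` method are hypothetical names, and the sketch assumes the stored arguments are forwarded to `fn` unchanged.

```python
import jax.numpy as jnp


class ParameterizedSketch:
    """Hypothetical wrapper: store a callable and its arguments, call on unwrap."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def unwrap(self):
        # The wrapped value is generated only when it is requested.
        return self.fn(*self.args, **self.kwargs)


def shifted_exp(x, *, shift):
    # Example parametrization: always greater than `shift`.
    return jnp.exp(x) + shift


# Store an unconstrained array; obtain a constrained value on unwrap.
log_scale = jnp.array([0.0, 1.0])
scale = ParameterizedSketch(jnp.exp, log_scale).unwrap()
bounded = ParameterizedSketch(shifted_exp, log_scale, shift=1.0).unwrap()
```

A typical use of this pattern is to optimize an unconstrained array such as `log_scale` while the rest of the model only ever sees the constrained result.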
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | The callable to execute upon unwrapping. | required |
| `*args` | `Any` | Positional arguments to pass to `fn`. | `()` |
| `**kwargs` | `Any` | Keyword arguments to pass to `fn`. | `{}` |
Source code in parax/wrappers.py
parax.Computed(fn, tree, *args, **kwargs)
Bases: AbstractUnwrappable[T]
Unwrap a PyTree by applying a function to its array-like leaves.
Corner Case Note: This relies on `eqx.is_array_like`. Non-array
leaves (e.g., strings, standard integers, metadata) inside `tree`
are bypassed and left intact, while any array-like objects
(including booleans) will be mapped.
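The leaf-wise mapping can be sketched with `jax.tree_util.tree_map`. This is an illustration under stated assumptions rather than the library's code: `apply_to_arrays` and `is_array` are hypothetical helpers, the array check is a simplified substitute for `eqx.is_array_like`, and the sketch assumes extra arguments are passed to `fn` after the leaf.

```python
import jax
import jax.numpy as jnp


def is_array(leaf):
    # Simplified stand-in for eqx.is_array_like (assumption).
    return isinstance(leaf, jax.Array)


def apply_to_arrays(fn, tree, *args, **kwargs):
    """Apply fn to array leaves only; return other leaves unchanged."""
    return jax.tree_util.tree_map(
        lambda leaf: fn(leaf, *args, **kwargs) if is_array(leaf) else leaf,
        tree,
    )


tree = {"raw_scale": jnp.array([-1.0, 0.0, 2.0]), "name": "layer_0"}

# Map a positivity transform over the arrays; the string leaf is skipped.
positive = apply_to_arrays(jax.nn.softplus, tree)

# Extra positional arguments are forwarded to fn after the leaf.
doubled = apply_to_arrays(jnp.multiply, tree, 2.0)
```

The output keeps the structure of the input tree, so metadata such as the `name` entry survives the mapping untouched.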
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The target PyTree to map the computation over. | required |
| `fn` | `Callable` | The function to apply to each array-like leaf in the tree. | required |
| `*args` | `Any` | Positional arguments passed to `fn`. | `()` |
| `**kwargs` | `Any` | Keyword arguments passed to `fn`. | `{}` |
Source code in parax/wrappers.py
parax.Static(tree)
Bases: AbstractUnwrappable[T], AbstractWrappable[T]
Wraps a tree and marks it as static.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The PyTree to mark as static. | required |
Source code in parax/wrappers.py