
Wrappers

parax.Frozen(tree)

Bases: AbstractUnwrappable[T], AbstractWrappable[T], AbstractConstant[T]

Applies jax.lax.stop_gradient to all array-like leaves before unwrapping.

Implements the AbstractConstant interface so it can be filtered out during optimization partitioning.
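The partitioning idea can be sketched in plain JAX, in the spirit of eqx.partition. This is not parax code: ToyConstant is a hypothetical marker standing in for an AbstractConstant subclass such as Frozen, and split_trainable is a hypothetical helper. It replaces constant nodes with None in the trainable half and keeps only them in the constant half.

```python
import jax
import jax.numpy as jnp


class ToyConstant:
    """Hypothetical marker standing in for an AbstractConstant subclass (not parax code)."""

    def __init__(self, tree):
        self.tree = tree


def _is_constant(x):
    return isinstance(x, ToyConstant)


def split_trainable(tree):
    """Hypothetical helper: split a PyTree into trainable and constant halves."""
    # Constant nodes become None in the trainable half...
    trainable = jax.tree_util.tree_map(
        lambda x: None if _is_constant(x) else x, tree, is_leaf=_is_constant
    )
    # ...and everything except constant nodes becomes None in the constant half.
    constant = jax.tree_util.tree_map(
        lambda x: x if _is_constant(x) else None, tree, is_leaf=_is_constant
    )
    return trainable, constant


params = {"w": jnp.ones(3), "embedding": ToyConstant(jnp.zeros(4))}
trainable, constant = split_trainable(params)
```

An optimizer would then only ever see the trainable half, while the constant half is recombined with it at call time.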

Corner Case Note: The __init__ prevents double-wrapping. If you pass a Frozen object into Frozen, its inner tree is reused rather than nested, so Frozen(Frozen(x)).tree is x.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The PyTree to freeze. | required |
Source code in parax/wrappers.py
```python
def __init__(self, tree: T):
    """
    Args:
        tree: The PyTree to freeze.
    """
    if isinstance(tree, Frozen):
        tree = tree.tree
    self.tree = tree
```
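The effect of the gradient stop can be illustrated with plain JAX. This sketch does not use parax: freeze_leaves and _is_array_like are hypothetical helpers, and _is_array_like only approximates eqx.is_array_like. Gradients do not flow through the frozen subtree, so its entry in the gradient is zero.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _is_array_like(x):
    # Approximation of eqx.is_array_like (assumption): arrays and Python numeric scalars.
    return isinstance(
        x, (jax.Array, np.ndarray, np.generic, bool, int, float, complex)
    ) or hasattr(x, "__jax_array__")


def freeze_leaves(tree):
    """Hypothetical helper: apply jax.lax.stop_gradient to every array-like leaf."""
    return jax.tree_util.tree_map(
        lambda x: jax.lax.stop_gradient(x) if _is_array_like(x) else x, tree
    )


def loss(params):
    frozen = freeze_leaves(params["frozen"])  # no gradient flows into this subtree
    return jnp.sum(frozen["w"] * params["b"])


params = {"frozen": {"w": jnp.array([1.0, 2.0])}, "b": jnp.array([3.0, 4.0])}
grads = jax.grad(loss)(params)
# grads["frozen"]["w"] is all zeros; grads["b"] equals w, i.e. [1.0, 2.0].
```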

parax.Parameterized(fn, *args, **kwargs)

Bases: AbstractUnwrappable[T]

Unwrap into an arbitrary object by calling a function with arguments.

Useful for injecting dynamically generated values (such as neural network outputs or complex parametrizations) into an otherwise static PyTree structure, since the value is only produced when the wrapper is unwrapped.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | The callable to execute upon unwrapping. | required |
| `*args` | `Any` | Positional arguments to pass to fn. | `()` |
| `**kwargs` | `Any` | Keyword arguments to pass to fn. | `{}` |
Source code in parax/wrappers.py
```python
def __init__(self, fn: Callable, *args: Any, **kwargs: Any):
    """
    Args:
        fn: The callable to execute upon unwrapping.
        *args: Positional arguments to pass to `fn`.
        **kwargs: Keyword arguments to pass to `fn`.
    """
    self.fn = fn
    self.args = tuple(args)
    self.kwargs = kwargs
```
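A minimal sketch of the deferred-call idea, assuming unwrapping simply evaluates fn(*args, **kwargs). ToyParameterized and unwrap_tree are hypothetical stand-ins written for illustration, not parax's implementation. Here a strictly positive scale is stored as an unconstrained log value and only materialized on unwrap.

```python
import jax
import jax.numpy as jnp


class ToyParameterized:
    """Hypothetical stand-in for parax.Parameterized (not the library's code)."""

    def __init__(self, fn, *args, **kwargs):
        self.fn = fn
        self.args = tuple(args)
        self.kwargs = kwargs

    def unwrap(self):
        # The wrapped value is only produced when unwrap() is called.
        return self.fn(*self.args, **self.kwargs)


def unwrap_tree(tree):
    """Hypothetical helper: replace each ToyParameterized node with its value."""
    is_param = lambda x: isinstance(x, ToyParameterized)
    return jax.tree_util.tree_map(
        lambda x: x.unwrap() if is_param(x) else x, tree, is_leaf=is_param
    )


# A strictly positive scale, parametrized through exp of an unconstrained value.
log_scale = jnp.array([0.0, 1.0])
model = {"scale": ToyParameterized(jnp.exp, log_scale), "name": "layer0"}
concrete = unwrap_tree(model)  # concrete["scale"] == exp(log_scale)
```

Optimizing log_scale rather than the scale itself keeps the unwrapped value positive without any explicit constraint.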

parax.Computed(fn, tree, *args, **kwargs)

Bases: AbstractUnwrappable[T]

Unwrap a PyTree by applying a function to its array-like leaves.

Corner Case Note: This relies on eqx.is_array_like. Leaves that are not array-like (e.g., strings, functions, or other metadata objects) inside tree are passed through unchanged, while every array-like leaf is mapped, including Python scalars such as bool, int, and float.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fn` | `Callable` | The function to apply to each array-like leaf in the tree. | required |
| `tree` | `T` | The target PyTree to map the computation over. | required |
| `*args` | `Any` | Positional arguments passed to fn after the leaf. | `()` |
| `**kwargs` | `Any` | Keyword arguments passed to fn. | `{}` |
Source code in parax/wrappers.py
```python
def __init__(self, fn: Callable, tree: T, *args: Any, **kwargs: Any):
    """
    Args:
        fn: The function to apply to each array-like leaf in the tree.
        tree: The target PyTree to map the computation over.
        *args: Positional arguments passed to `fn` after the leaf.
        **kwargs: Keyword arguments passed to `fn`.
    """
    self.tree = tree
    self.fn = fn
    self.args = tuple(args)
    self.kwargs = kwargs
```
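A minimal sketch of the mapping described above, assuming each array-like leaf is replaced by fn(leaf, *args, **kwargs). The compute and scale functions are hypothetical, and _is_array_like only approximates eqx.is_array_like; this is not parax's code.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _is_array_like(x):
    # Approximation of eqx.is_array_like (assumption).
    return isinstance(
        x, (jax.Array, np.ndarray, np.generic, bool, int, float, complex)
    ) or hasattr(x, "__jax_array__")


def compute(fn, tree, *args, **kwargs):
    """Hypothetical helper: call fn(leaf, *args, **kwargs) on each array-like leaf."""
    return jax.tree_util.tree_map(
        lambda leaf: fn(leaf, *args, **kwargs) if _is_array_like(leaf) else leaf,
        tree,
    )


def scale(leaf, factor, *, offset=0.0):
    # Extra positional arguments arrive after the leaf; keywords are forwarded as-is.
    return leaf * factor + offset


tree = {"w": jnp.array([1.0, -2.0]), "tag": "dense"}
out = compute(scale, tree, 2.0, offset=1.0)
# out["w"] is [3.0, -3.0]; the string leaf "dense" is passed through untouched.
```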

parax.Static(tree)

Bases: AbstractUnwrappable[T], AbstractWrappable[T]

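Wraps a tree and marks it as static. Like Frozen, the __init__ prevents double-wrapping: passing a Static object into Static reuses its inner tree rather than nesting the wrappers.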

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `T` | The PyTree to mark as static. | required |
Source code in parax/wrappers.py
```python
def __init__(self, tree: T):
    """
    Args:
        tree: The PyTree to mark as static.
    """
    if isinstance(tree, Static):
        tree = tree.tree
    self.tree = tree
```
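One plausible way to make a wrapped tree "static" in JAX is to carry it as pytree auxiliary data, so transformations see no leaves inside it. The sketch below assumes that mechanism; ToyStatic is a hypothetical stand-in, not parax's implementation. It also reproduces the documented double-wrapping guard.

```python
import jax


@jax.tree_util.register_pytree_node_class
class ToyStatic:
    """Hypothetical stand-in for parax.Static (not the library's implementation)."""

    def __init__(self, tree):
        # Same guard as the documented __init__: reuse an existing wrapper's tree.
        if isinstance(tree, ToyStatic):
            tree = tree.tree
        self.tree = tree

    def tree_flatten(self):
        # Assumption: the wrapped tree is stored as auxiliary data, so it has no
        # children and JAX transformations never trace or differentiate it.
        # Aux data should be hashable if the wrapper is passed through jax.jit.
        return (), self.tree

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


inner = ("relu", 128)  # hashable, metadata-like content
wrapped = ToyStatic(ToyStatic(inner))  # absorbed, not nested
```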