
Bounded

parax.bounded.AbstractBounded

Bases: Module, Generic[Base]

The abstract interface for a bounded PyTree.

Bounded objects expose their values in a "base" space, the space in which bounded optimizers operate.

Used as a type check for parax.is_bounded.

base abstractmethod property

Returns the current PyTree in base space.

bounds abstractmethod property

Returns the current bounds in base space as a (lower, upper) pair of PyTrees.

Each of lower and upper must have the same PyTree structure as self.base.

update(base) abstractmethod

Returns a new instance of this object updated with a new base PyTree.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| base | Base | The new base-space PyTree representing the updated state. | required |

Returns:

| Type | Description |
|------|-------------|
| AbstractBounded | A new instance of the bounded object, updated to reflect the new base. |

Source code in parax/bounded.py
@abstractmethod
def update(self, base: Base) -> "AbstractBounded":
    """
    Returns a new instance of this object updated with a new base PyTree.

    Args:
        base: The new base-space PyTree representing the updated state.

    Returns:
        A new instance of the bounded object, updated to reflect the new base.
    """
    pass
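To show how the three abstract members fit together, here is a minimal sketch of a concrete bounded type. It assumes a simple box constraint whose base space is the raw value itself. `BoundedInterface` and `BoxBounded` are hypothetical names, and the stand-in interface is a plain `abc.ABC` rather than the real `AbstractBounded` (an Equinox `Module`), so the snippet runs without parax or Equinox installed and is not itself registered as a PyTree.

```python
import abc

import jax.numpy as jnp


class BoundedInterface(abc.ABC):
    """Hypothetical stand-in for parax.bounded.AbstractBounded (no Equinox)."""

    @property
    @abc.abstractmethod
    def base(self):
        """The current PyTree in base space."""

    @property
    @abc.abstractmethod
    def bounds(self):
        """A (lower, upper) pair, each matching the structure of `base`."""

    @abc.abstractmethod
    def update(self, base):
        """Return a new instance carrying `base`; never mutate `self`."""


class BoxBounded(BoundedInterface):
    """Hypothetical box constraint whose base space is the value itself."""

    def __init__(self, value, lower, upper):
        self._value = jnp.asarray(value)
        self._lower = jnp.asarray(lower)
        self._upper = jnp.asarray(upper)

    @property
    def base(self):
        return self._value

    @property
    def bounds(self):
        return self._lower, self._upper

    def update(self, base):
        # Build a fresh object so the original stays untouched, matching the
        # "returns a new instance" contract of update.
        return BoxBounded(base, self._lower, self._upper)


param = BoxBounded([0.25, 0.75], lower=[0.0, 0.0], upper=[1.0, 1.0])
moved = param.update(jnp.array([0.5, 0.5]))
```

Because `update` returns a new object, `param` still holds its original base after `moved` is created, which is what lets `tree_update` rebuild a model functionally.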

parax.bounded.tree_base(model)

Extracts a PyTree of base values from a model.

Leaves that are not bounded, including standard inexact arrays, are returned unchanged.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | PyTree | The original PyTree model potentially containing bounded nodes. | required |

Returns:

| Type | Description |
|------|-------------|
| PyTree | A PyTree containing the extracted base values. |

Source code in parax/bounded.py
def tree_base(model: PyTree) -> PyTree:
    """
    Extracts a PyTree of base values from a model. 

    Standard inexact arrays are left intact.

    Args:
        model: The original PyTree model potentially containing bounded nodes.

    Returns:
        A PyTree containing the extracted base values.
    """
    from parax.filters import is_bounded
    def _extract(x):
        if not is_bounded(x):
            return x
        return x.base

    return jax.tree_util.tree_map(_extract, model, is_leaf=is_bounded)
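A minimal usage sketch of `tree_base`, assuming a hypothetical frozen-dataclass bounded node `Box` and an `isinstance`-based `is_bounded` stand-in (parax itself is not imported). The mapping logic matches the source above: bounded nodes are treated as leaves and replaced by their `base`, while every other leaf passes through.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Box:
    """Hypothetical bounded node: the base space is the raw value."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    @property
    def base(self):
        return self.value


def is_bounded(x):
    # Stand-in for parax.is_bounded, which is an isinstance-style type check.
    return isinstance(x, Box)


def tree_base(model):
    # Same logic as parax.bounded.tree_base: stop at bounded nodes, swap in .base.
    return jax.tree_util.tree_map(
        lambda x: x.base if is_bounded(x) else x, model, is_leaf=is_bounded
    )


model = {
    "w": Box(jnp.array([0.2, 0.8]), jnp.zeros(2), jnp.ones(2)),
    "b": jnp.array(3.0),
    "act": "relu",  # non-array leaf, passed through untouched
}
base = tree_base(model)
```

The result holds ordinary arrays wherever `model` had bounded nodes, so it is the tree a bounded optimizer would operate on.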

parax.bounded.tree_lower(tree)

Extracts the lower bounds of a potentially bounded PyTree in base space.

Standard inexact arrays default to a lower bound of -inf; all other non-bounded leaves are returned unchanged.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree | The PyTree model to extract lower bounds from. | required |

Returns:

| Type | Description |
|------|-------------|
| PyTree | A PyTree representing the lower bounds in base space. |

Source code in parax/bounded.py
def tree_lower(tree: PyTree) -> PyTree:
    """
    Extracts the lower bounds of a potentially bounded PyTree in base space. 

    Standard inexact arrays default to -inf; other leaves are returned unchanged.

    Args:
        tree: The PyTree model to extract lower bounds from.

    Returns:
        A PyTree representing the lower bounds in base space.
    """
    from parax.filters import is_bounded

    def _get_lower(x):
        if is_bounded(x):
            return x.bounds[0]
        if eqx.is_inexact_array(x):
            return jnp.full_like(x, -jnp.inf)
        return x

    lower = jax.tree_util.tree_map(_get_lower, tree, is_leaf=is_bounded)
    return lower
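The sketch below exercises the three branches of `_get_lower` on a mixed tree. It assumes a hypothetical dataclass bounded node `Box` and a local `is_bounded` stand-in, and replaces `eqx.is_inexact_array` with a simplified local check (`is_inexact_array`, JAX arrays only) so it runs without parax or Equinox.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Box:
    """Hypothetical bounded node storing its base value and (lower, upper)."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    @property
    def bounds(self):
        return self.lower, self.upper


def is_bounded(x):
    # Stand-in for parax.is_bounded.
    return isinstance(x, Box)


def is_inexact_array(x):
    # Simplified stand-in for eqx.is_inexact_array (JAX arrays only).
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def tree_lower(tree):
    def _get_lower(x):
        if is_bounded(x):
            return x.bounds[0]  # bounded node: its own lower bound
        if is_inexact_array(x):
            return jnp.full_like(x, -jnp.inf)  # float array: unbounded below
        return x  # anything else (int arrays, strings, ...) passes through

    return jax.tree_util.tree_map(_get_lower, tree, is_leaf=is_bounded)


model = {
    "w": Box(jnp.array([0.2, 0.8]), jnp.zeros(2), jnp.ones(2)),
    "b": jnp.array([1.0, 2.0]),
    "steps": jnp.array(5),  # integer array, so not inexact
    "act": "relu",
}
lower = tree_lower(model)
```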

parax.bounded.tree_upper(tree)

Extracts the upper bounds of a potentially bounded PyTree in base space.

Standard inexact arrays default to an upper bound of +inf; all other non-bounded leaves are returned unchanged.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree | The PyTree model to extract upper bounds from. | required |

Returns:

| Type | Description |
|------|-------------|
| PyTree | A PyTree representing the upper bounds in base space. |

Source code in parax/bounded.py
def tree_upper(tree: PyTree) -> PyTree:
    """
    Extracts the upper bounds of a potentially bounded PyTree in base space. 

    Standard inexact arrays default to +inf; other leaves are returned unchanged.

    Args:
        tree: The PyTree model to extract upper bounds from.

    Returns:
        A PyTree representing the upper bounds in base space.
    """
    from parax.filters import is_bounded

    def _get_upper(x):
        if is_bounded(x):
            return x.bounds[1]
        if eqx.is_inexact_array(x):
            return jnp.full_like(x, jnp.inf)
        return x

    upper = jax.tree_util.tree_map(_get_upper, tree, is_leaf=is_bounded)
    return upper
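Since a bounded node's base can itself be a PyTree, `tree_upper` can expand a single bounded leaf into a whole subtree. The sketch below assumes a hypothetical `Box` node whose value and bounds are dicts, with local stand-ins for `is_bounded` and a simplified `eqx.is_inexact_array`, so it runs without parax or Equinox.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Box:
    """Hypothetical bounded node; value, lower and upper share one structure."""

    value: dict
    lower: dict
    upper: dict

    @property
    def bounds(self):
        return self.lower, self.upper


def is_bounded(x):
    # Stand-in for parax.is_bounded.
    return isinstance(x, Box)


def is_inexact_array(x):
    # Simplified stand-in for eqx.is_inexact_array (JAX arrays only).
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def tree_upper(tree):
    def _get_upper(x):
        if is_bounded(x):
            return x.bounds[1]
        if is_inexact_array(x):
            return jnp.full_like(x, jnp.inf)
        return x

    return jax.tree_util.tree_map(_get_upper, tree, is_leaf=is_bounded)


affine = Box(
    value={"scale": jnp.array(0.5), "shift": jnp.array(0.0)},
    lower={"scale": jnp.array(0.0), "shift": jnp.array(-1.0)},
    upper={"scale": jnp.array(2.0), "shift": jnp.array(1.0)},
)
model = [affine, jnp.array([4.0, 5.0])]
upper = tree_upper(model)
# The bounded leaf becomes the dict {"scale": 2.0, "shift": 1.0};
# the plain float array becomes [inf, inf].
```

The output mirrors the structure that `tree_base` would produce for the same model, which is what allows values and bounds to be paired leaf by leaf.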

parax.bounded.tree_bounds(tree)

Extracts two PyTrees (lower and upper) representing the boundaries of the base space.

Standard inexact arrays default to (-inf, inf); other non-bounded leaves are returned unchanged in both trees.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree | The PyTree model to extract bounds from. | required |

Returns:

| Type | Description |
|------|-------------|
| tuple[PyTree, PyTree] | A tuple of two PyTrees (lower_bounds, upper_bounds). |

Source code in parax/bounded.py
def tree_bounds(tree: PyTree) -> tuple[PyTree, PyTree]:
    """
    Extracts two PyTrees (lower and upper) representing the boundaries of 
    the base space. 

    Standard arrays default to (-inf, inf).

    Args:
        tree: The PyTree model to extract bounds from.

    Returns:
        A tuple of two PyTrees `(lower_bounds, upper_bounds)`.
    """
    return tree_lower(tree), tree_upper(tree)
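One use for the `(lower, upper)` pair is projecting a base-space tree back into the feasible box after an unconstrained step. The sketch below shows one way to do that and is not part of parax; it assumes a hypothetical `Box` node and `is_bounded` stand-in, and simplifies the helpers by treating every non-bounded leaf as a float array.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Box:
    """Hypothetical bounded node: base space is the raw value."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    @property
    def base(self):
        return self.value

    @property
    def bounds(self):
        return self.lower, self.upper


def is_bounded(x):
    # Stand-in for parax.is_bounded.
    return isinstance(x, Box)


def _map(f, tree):
    return jax.tree_util.tree_map(f, tree, is_leaf=is_bounded)


def tree_base(tree):
    return _map(lambda x: x.base if is_bounded(x) else x, tree)


def tree_bounds(tree):
    # Simplified: assumes every non-bounded leaf is a float array,
    # which then defaults to (-inf, inf) as described above.
    lower = _map(lambda x: x.bounds[0] if is_bounded(x) else jnp.full_like(x, -jnp.inf), tree)
    upper = _map(lambda x: x.bounds[1] if is_bounded(x) else jnp.full_like(x, jnp.inf), tree)
    return lower, upper


model = {
    "w": Box(jnp.array([0.2, 0.8]), jnp.zeros(2), jnp.ones(2)),
    "b": jnp.array([3.0]),
}
base = tree_base(model)
lower, upper = tree_bounds(model)

# Take an unconstrained step in base space, then clip element-wise into the box.
stepped = jax.tree_util.tree_map(lambda x: x + 0.5, base)
projected = jax.tree_util.tree_map(
    lambda x, lo, hi: jnp.minimum(jnp.maximum(x, lo), hi), stepped, lower, upper
)
# projected["w"] is about [0.7, 1.0]; projected["b"] stays [3.5] (infinite bounds)
```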

parax.bounded.tree_update(model, base_model)

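Takes an updated base-space PyTree and injects it back into the original bounded model structure by calling each bounded node's update method with its new base value; non-bounded leaves are taken directly from base_model.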

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | PyTree | The original PyTree model containing the bounded nodes. | required |
| base_model | PyTree | The updated PyTree containing the new base values. | required |

Returns:

| Type | Description |
|------|-------------|
| PyTree | A new PyTree model with its internal states reconstructed to reflect the updated base values. |

Source code in parax/bounded.py
def tree_update(model: PyTree, base_model: PyTree) -> PyTree:
    """
    Takes an updated base-space PyTree and injects it back into the
    original bounded model structure by calling each bounded node's `update`.

    Args:
        model: The original PyTree model containing the bounded nodes.
        base_model: The updated PyTree containing the new base values.

    Returns:
        A new PyTree model with its internal states reconstructed to reflect 
        the updated base values.
    """
    from parax.filters import is_bounded

    def _rebuild(orig, base):
        if is_bounded(orig):
            return orig.update(base)
        return base

    return jax.tree_util.tree_map(_rebuild, model, base_model, is_leaf=is_bounded)
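Putting the pieces together, the sketch below performs one functional round trip: extract the base tree, transform it, and rebuild the model with `tree_update`. It assumes a hypothetical frozen-dataclass `Box` whose `update` returns a copy via `dataclasses.replace`, plus an `isinstance`-based `is_bounded` stand-in; the halving step is a placeholder for whatever update an optimizer would compute.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Box:
    """Hypothetical bounded node: base space is the raw value."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    @property
    def base(self):
        return self.value

    def update(self, base):
        # Return a copy with the new base; bounds are carried over unchanged.
        return dataclasses.replace(self, value=base)


def is_bounded(x):
    # Stand-in for parax.is_bounded.
    return isinstance(x, Box)


def tree_base(model):
    return jax.tree_util.tree_map(
        lambda x: x.base if is_bounded(x) else x, model, is_leaf=is_bounded
    )


def tree_update(model, base_model):
    # Same logic as parax.bounded.tree_update: bounded nodes rebuild themselves
    # from their new base, every other leaf is taken from base_model.
    def _rebuild(orig, base):
        return orig.update(base) if is_bounded(orig) else base

    return jax.tree_util.tree_map(_rebuild, model, base_model, is_leaf=is_bounded)


model = {
    "w": Box(jnp.array([0.2, 0.8]), jnp.zeros(2), jnp.ones(2)),
    "b": jnp.array(3.0),
}
new_base = jax.tree_util.tree_map(lambda x: 0.5 * x, tree_base(model))
new_model = tree_update(model, new_base)
```

`tree_update` itself performs no clipping; whether an out-of-bounds base is accepted is up to each node's `update` method.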