
Bounds

parax.bounds.AbstractBounded

Bases: Module, Generic[T]

The abstract interface for a bounded PyTree.

This is the type that parax.is_bounded checks for.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| bounds | AbstractVar[tuple[T, T]] | The current (lower, upper) bounds of the PyTree. Each bound must have the same PyTree structure as self. |

parax.bounds.tree_lower(tree)

Extracts the lower bounds of a potentially bounded PyTree.

Plain (unbounded) arrays default to the interval (-inf, inf), so their lower bound is -inf.

Only two kinds of leaves are accepted: instances of parax.bounds.AbstractBounded and inexact (floating-point or complex) arrays. Any other leaf, including integer arrays and Python scalars, raises a ValueError. Mark such leaves as static or filter them out first, e.g. with eqx.filter.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree | The PyTree model to extract lower bounds from. | required |

Returns:

| Type | Description |
|------|-------------|
| PyTree | A PyTree of lower bounds with the same structure as tree. |

Source code in parax/bounds.py
def tree_lower(tree: PyTree) -> PyTree:
    """
    Extracts the lower bounds of a potentially bounded PyTree. 

    Standard arrays default to (-inf, inf).

    Note that this function does not allow non-array/bounded leaf nodes.
    If you have leaves in your tree that are neither arrays nor derive
    from `parax.bounds.AbstractBounded`, be sure to mark
    them as static or filter them out using e.g. `eqx.filter` first.    

    Args:
        tree: The PyTree model to extract lower bounds from.

    Returns:
        A PyTree representing the lower bounds.
    """
    def _get_lower(path, x):
        if is_bounded(x):
            return x.bounds[0]
        if eqx.is_inexact_array(x):
            return jnp.full_like(x, -jnp.inf)
        raise ValueError(f"Found a leaf node of type {type(x)} that is neither bounded nor an array in `parax.bounds.tree_lower`. Value: {x}, path: {path}")

    lower = jax.tree.map_with_path(_get_lower, tree, is_leaf=is_bounded)
    return lower

parax.bounds.tree_upper(tree)

Extracts the upper bounds of a potentially bounded PyTree.

Plain (unbounded) arrays default to the interval (-inf, inf), so their upper bound is +inf.

Only two kinds of leaves are accepted: instances of parax.bounds.AbstractBounded and inexact (floating-point or complex) arrays. Any other leaf, including integer arrays and Python scalars, raises a ValueError. Mark such leaves as static or filter them out first, e.g. with eqx.filter.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree | The PyTree model to extract upper bounds from. | required |

Returns:

| Type | Description |
|------|-------------|
| PyTree | A PyTree of upper bounds with the same structure as tree. |

Source code in parax/bounds.py
def tree_upper(tree: PyTree) -> PyTree:
    """
    Extracts the upper bounds of a potentially bounded PyTree. 

    Standard arrays default to (-inf, inf).

    Note that this function does not allow non-array/bounded leaf nodes.
    If you have leaves in your tree that are neither arrays nor derive
    from `parax.bounds.AbstractBounded`, be sure to mark
    them as static or filter them out using e.g. `eqx.filter` first.    

    Args:
        tree: The PyTree model to extract upper bounds from.

    Returns:
        A PyTree representing the upper bounds.
    """
    def _get_upper(path, x):
        if is_bounded(x):
            return x.bounds[1]
        if eqx.is_inexact_array(x):
            return jnp.full_like(x, jnp.inf)
        raise ValueError(f"Found a leaf node of type {type(x)} that is neither bounded nor an array in `parax.bounds.tree_upper`. Value: {x}, path: {path}")

    upper = jax.tree.map_with_path(_get_upper, tree, is_leaf=is_bounded)
    return upper

parax.bounds.tree_bounds(tree)

Extracts the lower and upper bounds of tree as two separate PyTrees.

Plain (unbounded) arrays default to the interval (-inf, inf).

Only two kinds of leaves are accepted: instances of parax.bounds.AbstractBounded and inexact (floating-point or complex) arrays. Any other leaf, including integer arrays and Python scalars, raises a ValueError. Mark such leaves as static or filter them out first, e.g. with eqx.filter.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| tree | PyTree | The PyTree model to extract bounds from. | required |

Returns:

| Type | Description |
|------|-------------|
| tuple[PyTree, PyTree] | A tuple of two PyTrees (lower_bounds, upper_bounds). |

Source code in parax/bounds.py
def tree_bounds(tree: PyTree) -> tuple[PyTree, PyTree]:
    """
    Extracts two PyTrees (lower and upper) representing the boundaries. 

    Standard arrays default to (-inf, inf).

    Note that this function does not allow non-array/bounded leaf nodes.
    If you have leaves in your tree that are neither arrays nor derive
    from `parax.bounds.AbstractBounded`, be sure to mark
    them as static or filter them out using e.g. `eqx.filter` first.    

    Args:
        tree: The PyTree model to extract bounds from.

    Returns:
        A tuple of two PyTrees `(lower_bounds, upper_bounds)`.
    """
    return tree_lower(tree), tree_upper(tree)