Bounded
parax.bounded.AbstractBounded
Bases: Module, Generic[Base]
The abstract interface for a bounded PyTree.
Built around the concept of a "base" space: the space in which bounded optimizers operate.
Used as the type check behind parax.is_bounded.
base
abstractmethod
property
Returns the current PyTree in base space.
bounds
abstractmethod
property
Returns the current PyTree bounds in base space.
Must have the same PyTree structure as self.base.
update(base)
abstractmethod
Returns a new instance of this object updated with a new base PyTree.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base | Base | The new base-space PyTree representing the updated state. | required |
Returns:
| Type | Description |
|---|---|
| AbstractBounded | A new instance of the bounded object, updated to reflect the new base. |
Source code in parax/bounded.py
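To illustrate the contract, here is a minimal sketch of a concrete bounded node, assuming a plain abc.ABC base in place of parax's Module base class. BoundedSketch and Interval are hypothetical names, not part of parax. Returning a (lower, upper) tuple from bounds, each half matching the structure of base, is also an assumption about how "matching PyTree structure as self.base" is meant.

```python
import abc
import dataclasses
from typing import Generic, TypeVar

import jax
import jax.numpy as jnp

Base = TypeVar("Base")


class BoundedSketch(abc.ABC, Generic[Base]):
    """Hypothetical stand-in mirroring the AbstractBounded interface (not parax's API)."""

    @property
    @abc.abstractmethod
    def base(self) -> Base:
        """The current PyTree in base space."""

    @property
    @abc.abstractmethod
    def bounds(self) -> tuple[Base, Base]:
        """Assumed layout: (lower, upper), each with the same structure as base."""

    @abc.abstractmethod
    def update(self, base: Base) -> "BoundedSketch[Base]":
        """Return a new instance built from base; never mutate self."""


@dataclasses.dataclass(frozen=True)
class Interval(BoundedSketch[jax.Array]):
    """Hypothetical box-bounded node whose base value is stored directly."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    @property
    def base(self) -> jax.Array:
        return self.value

    @property
    def bounds(self) -> tuple[jax.Array, jax.Array]:
        return (self.lower, self.upper)

    def update(self, base: jax.Array) -> "Interval":
        # dataclasses.replace builds a new frozen instance; the bounds are carried over
        return dataclasses.replace(self, value=base)


x = Interval(jnp.array(0.5), jnp.array(0.0), jnp.array(1.0))
y = x.update(jnp.array(0.25))  # x itself is unchanged
```

Making the node a frozen dataclass enforces the "returns a new instance" part of update: any attempt to assign to a field in place raises an error.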
parax.bounded.tree_base(model)
Extracts a PyTree of base values from a model.
Standard inexact arrays are left intact.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | PyTree | The original PyTree model potentially containing bounded nodes. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | A PyTree containing the extracted base values. |
Source code in parax/bounded.py
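One way to picture this extraction is a jax.tree_util.tree_map that stops descending at bounded nodes and swaps each one for its base value. The sketch below assumes a hypothetical Interval node and a local is_bounded_node helper standing in for parax.is_bounded. It passes every non-bounded leaf through unchanged, which may differ from how parax treats leaves that are not inexact arrays.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Interval:
    """Hypothetical bounded node (not parax's API); its base value is stored directly."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    @property
    def base(self) -> jax.Array:
        return self.value


def is_bounded_node(x) -> bool:
    # Local stand-in for parax.is_bounded
    return isinstance(x, Interval)


def tree_base_sketch(model):
    """Replace each bounded node with its base value; pass other leaves through."""
    return jax.tree_util.tree_map(
        lambda node: node.base if is_bounded_node(node) else node,
        model,
        is_leaf=is_bounded_node,  # treat bounded nodes as leaves instead of recursing
    )


model = {
    "weight": jnp.ones(3),
    "scale": Interval(jnp.array(0.5), jnp.array(0.0), jnp.array(1.0)),
}
base = tree_base_sketch(model)
# base["weight"] is the original array object; base["scale"] is the raw value 0.5
```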
parax.bounded.tree_lower(tree)
Extracts the lower bounds of a potentially bounded PyTree in base space.
For standard (unbounded) arrays, the lower bound defaults to -inf.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model to extract lower bounds from. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | A PyTree representing the lower bounds in base space. |
Source code in parax/bounded.py
parax.bounded.tree_upper(tree)
Extracts the upper bounds of a potentially bounded PyTree in base space.
For standard (unbounded) arrays, the upper bound defaults to +inf.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model to extract upper bounds from. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | A PyTree representing the upper bounds in base space. |
Source code in parax/bounded.py
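The lower and upper extractions are mirror images, so a sketch of both fits in one place. It assumes a hypothetical Interval node whose bounds property returns a (lower, upper) tuple already expressed in base space, and it assumes every unbounded leaf is a floating-point array (jnp.full_like cannot represent infinity in an integer dtype).

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Interval:
    """Hypothetical bounded node (not parax's API) with box bounds in base space."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    @property
    def bounds(self):
        return (self.lower, self.upper)


def is_bounded_node(x) -> bool:
    # Local stand-in for parax.is_bounded
    return isinstance(x, Interval)


def tree_lower_sketch(tree):
    # Bounded nodes report their lower bound; plain floating arrays get -inf
    return jax.tree_util.tree_map(
        lambda n: n.bounds[0] if is_bounded_node(n) else jnp.full_like(n, -jnp.inf),
        tree,
        is_leaf=is_bounded_node,
    )


def tree_upper_sketch(tree):
    # Bounded nodes report their upper bound; plain floating arrays get +inf
    return jax.tree_util.tree_map(
        lambda n: n.bounds[1] if is_bounded_node(n) else jnp.full_like(n, jnp.inf),
        tree,
        is_leaf=is_bounded_node,
    )


model = {
    "weight": jnp.zeros(2),
    "scale": Interval(jnp.array(0.5), jnp.array(0.0), jnp.array(1.0)),
}
lower = tree_lower_sketch(model)  # weight -> [-inf, -inf], scale -> 0.0
upper = tree_upper_sketch(model)  # weight -> [inf, inf], scale -> 1.0
```

Filling unbounded leaves with arrays shaped like the leaf, rather than a single scalar, keeps both results structurally identical to the base tree, so they can be zipped leaf by leaf with it.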
parax.bounded.tree_bounds(tree)
Extracts two PyTrees (lower and upper) representing the boundaries of the base space.
Standard arrays default to (-inf, inf).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model to extract bounds from. | required |
Returns:
| Type | Description |
|---|---|
| tuple[PyTree, PyTree] | A tuple (lower, upper) of PyTrees holding the lower and upper bounds in base space. |
Source code in parax/bounded.py
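A bounded optimizer working in base space needs both bounds at once. The sketch below shows one common way to use them, a projected step (add an update, then clip back into the box), as an illustration rather than the method parax's optimizers use. Interval, is_bounded_node, tree_bounds_sketch, and projected_step are hypothetical names, and unbounded leaves are assumed to be floating-point arrays.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Interval:
    """Hypothetical bounded node (not parax's API); its base value is stored directly."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array


def is_bounded_node(x) -> bool:
    # Local stand-in for parax.is_bounded
    return isinstance(x, Interval)


def tree_bounds_sketch(tree):
    """Return (lower, upper) PyTrees; unbounded leaves get (-inf, inf)."""

    def lo(n):
        return n.lower if is_bounded_node(n) else jnp.full_like(n, -jnp.inf)

    def hi(n):
        return n.upper if is_bounded_node(n) else jnp.full_like(n, jnp.inf)

    lower = jax.tree_util.tree_map(lo, tree, is_leaf=is_bounded_node)
    upper = jax.tree_util.tree_map(hi, tree, is_leaf=is_bounded_node)
    return lower, upper


def projected_step(model, base, step):
    """Compute base + step, then clip each leaf back into its [lower, upper] box."""
    lower, upper = tree_bounds_sketch(model)
    candidate = jax.tree_util.tree_map(jnp.add, base, step)
    return jax.tree_util.tree_map(jnp.clip, candidate, lower, upper)


model = {
    "weight": jnp.array([2.0, -3.0]),
    "scale": Interval(jnp.array(0.9), jnp.array(0.0), jnp.array(1.0)),
}
# Base-space view of model, i.e. what a tree_base-style extraction would return
base = {"weight": model["weight"], "scale": model["scale"].value}
step = {"weight": jnp.array([1.0, 1.0]), "scale": jnp.array(0.5)}
new_base = projected_step(model, base, step)
# weight is unbounded -> [3.0, -2.0]; scale 0.9 + 0.5 = 1.4 is clipped to 1.0
```

Because unbounded leaves carry infinite bounds, the same clip applies uniformly to every leaf and leaves them untouched.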
parax.bounded.tree_update(model, base_model)
Takes an updated base-space PyTree and injects it back into the original bounded model structure, calling update on each bounded node.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | PyTree | The original PyTree model containing the bounded nodes. | required |
| base_model | PyTree | The updated PyTree containing the new base values. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | A new PyTree model with its internal states reconstructed to reflect the updated base values. |
Source code in parax/bounded.py
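The reverse direction can be sketched as a two-tree jax.tree_util.tree_map: walk the original model, and wherever a bounded node appears, rebuild it from the matching entry of the base tree with its own update method. The Interval node, is_bounded_node, and tree_update_sketch are hypothetical names assumed for illustration; base_model is assumed to have the original model (with bounded nodes treated as leaves) as a structural prefix.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Interval:
    """Hypothetical bounded node (not parax's API); its base value is stored directly."""

    value: jax.Array
    lower: jax.Array
    upper: jax.Array

    def update(self, base):
        # New instance with the new base value; bounds are carried over
        return dataclasses.replace(self, value=base)


def is_bounded_node(x) -> bool:
    # Local stand-in for parax.is_bounded
    return isinstance(x, Interval)


def tree_update_sketch(model, base_model):
    """Rebuild model from base_model.

    Bounded nodes are rebuilt with their own update(); every other leaf is
    replaced by the corresponding entry of base_model.
    """
    return jax.tree_util.tree_map(
        lambda node, new_base: node.update(new_base) if is_bounded_node(node) else new_base,
        model,
        base_model,
        is_leaf=is_bounded_node,  # applies to model, the tree that defines the structure
    )


model = {
    "weight": jnp.zeros(2),
    "scale": Interval(jnp.array(0.5), jnp.array(0.0), jnp.array(1.0)),
}
new_base = {"weight": jnp.ones(2), "scale": jnp.array(0.25)}
updated = tree_update_sketch(model, new_base)
# updated["scale"] is a new Interval with value 0.25 and the original bounds;
# model itself is left unchanged
```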