Bounds
parax.bounds.AbstractBounded
Bases: Module, Generic[T]
The abstract interface for a bounded PyTree.
parax.is_bounded uses this class as its type check.
Attributes:

| Name | Type | Description |
|---|---|---|
| `bounds` | `AbstractVar[tuple[T, T]]` | Returns the current PyTree bounds. Each must have a matching PyTree structure as |
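The interface boils down to one requirement: a concrete subclass must expose `bounds` as a `(lower, upper)` pair, and a bounded check is an `isinstance` test against the base class. Below is a minimal sketch of that idea. It is an assumption-laden stand-in, not parax's code: to stay self-contained it uses `abc` instead of `eqx.Module` and `AbstractVar`, and the names `AbstractBoundedSketch`, `BoundedArray` and `is_bounded_sketch` are hypothetical.

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

import jax.numpy as jnp

T = TypeVar("T")


class AbstractBoundedSketch(ABC, Generic[T]):
    """Hypothetical stand-in for AbstractBounded, built on abc instead of equinox."""

    @property
    @abstractmethod
    def bounds(self) -> tuple[T, T]:
        """Return the (lower, upper) bounds."""


class BoundedArray(AbstractBoundedSketch[jnp.ndarray]):
    """Hypothetical concrete bounded leaf: an array with elementwise bounds."""

    def __init__(self, value, lower, upper):
        self.value = jnp.asarray(value)
        # Broadcast (possibly scalar) bounds to the value's shape and dtype.
        self.lower = jnp.broadcast_to(jnp.asarray(lower, dtype=self.value.dtype), self.value.shape)
        self.upper = jnp.broadcast_to(jnp.asarray(upper, dtype=self.value.dtype), self.value.shape)

    @property
    def bounds(self) -> tuple[jnp.ndarray, jnp.ndarray]:
        return (self.lower, self.upper)


def is_bounded_sketch(x) -> bool:
    """Hypothetical analogue of parax.is_bounded: an isinstance check on the base."""
    return isinstance(x, AbstractBoundedSketch)


p = BoundedArray(jnp.array([0.5, 2.0]), 0.0, 1.0)
lo, hi = p.bounds  # each has shape (2,)
```

Because `bounds` is abstract, the base class itself cannot be instantiated; only subclasses that provide it can.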
parax.bounds.tree_lower(tree)
Extracts the lower bounds of a potentially bounded PyTree.
Standard (unbounded) arrays default to the interval (-inf, inf), so their lower bound is -inf.
This function does not accept leaves that are neither arrays nor instances of parax.bounds.AbstractBounded. If your tree contains such leaves, mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree` | The PyTree model to extract lower bounds from. | required |
Returns:

| Type | Description |
|---|---|
| `PyTree` | A PyTree representing the lower bounds. |
Source code in parax/bounds.py
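The behaviour described above can be illustrated with a small sketch built on `jax.tree_util.tree_map`. This is not parax's implementation: `Bounded` and `tree_lower_sketch` are hypothetical, `Bounded` is a plain unregistered class standing in for an `AbstractBounded` subclass, and the sketch assumes floating-point array leaves (filling an integer array with -inf is not meaningful).

```python
import jax
import jax.numpy as jnp


class Bounded:
    """Hypothetical bounded leaf standing in for an AbstractBounded subclass."""

    def __init__(self, value, lower, upper):
        self.value = jnp.asarray(value)
        self.bounds = (jnp.asarray(lower), jnp.asarray(upper))


def _is_bounded(x):
    return isinstance(x, Bounded)


def tree_lower_sketch(tree):
    """Map every leaf to its lower bound (assumes floating-point arrays)."""

    def lower(leaf):
        if _is_bounded(leaf):
            return leaf.bounds[0]
        if isinstance(leaf, jax.Array):
            # Unbounded array: default interval (-inf, inf), so the lower bound is -inf.
            return jnp.full_like(leaf, -jnp.inf)
        raise TypeError(f"unsupported leaf of type {type(leaf).__name__}")

    # is_leaf stops traversal at bounded nodes. Bounded is unregistered, so JAX already
    # treats it as a leaf; the check matters for pytree-registered types such as
    # eqx.Module subclasses.
    return jax.tree_util.tree_map(lower, tree, is_leaf=_is_bounded)


params = {"w": jnp.zeros(3), "b": Bounded(0.5, 0.0, 1.0)}
lower = tree_lower_sketch(params)  # lower["b"] is 0.0; lower["w"] is filled with -inf
```

A leaf such as a string reaches the final branch and raises `TypeError`, which mirrors the restriction on non-array leaves.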
parax.bounds.tree_upper(tree)
Extracts the upper bounds of a potentially bounded PyTree.
Standard (unbounded) arrays default to the interval (-inf, inf), so their upper bound is inf.
This function does not accept leaves that are neither arrays nor instances of parax.bounds.AbstractBounded. If your tree contains such leaves, mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree` | The PyTree model to extract upper bounds from. | required |
Returns:

| Type | Description |
|---|---|
| `PyTree` | A PyTree representing the upper bounds. |
Source code in parax/bounds.py
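The filtering advice can be sketched as follows, assuming a tree that mixes arrays, a bounded leaf and a string. `drop_unsupported` replaces every leaf that is neither an array nor bounded with `None`, in the same spirit as `eqx.filter`, which by default replaces filtered-out leaves with `None`; JAX treats `None` as an empty subtree, so the bound extraction skips it. `Bounded`, `tree_upper_sketch` and `drop_unsupported` are hypothetical names, not parax's API, and floating-point arrays are assumed.

```python
import jax
import jax.numpy as jnp


class Bounded:
    """Hypothetical bounded leaf standing in for an AbstractBounded subclass."""

    def __init__(self, value, lower, upper):
        self.value = jnp.asarray(value)
        self.bounds = (jnp.asarray(lower), jnp.asarray(upper))


def _is_bounded(x):
    return isinstance(x, Bounded)


def tree_upper_sketch(tree):
    """Map every leaf to its upper bound (assumes floating-point arrays)."""

    def upper(leaf):
        if _is_bounded(leaf):
            return leaf.bounds[1]
        if isinstance(leaf, jax.Array):
            # Unbounded array: default interval (-inf, inf), so the upper bound is +inf.
            return jnp.full_like(leaf, jnp.inf)
        raise TypeError(f"unsupported leaf of type {type(leaf).__name__}")

    return jax.tree_util.tree_map(upper, tree, is_leaf=_is_bounded)


def drop_unsupported(tree):
    """Replace leaves that are neither arrays nor bounded with None (an empty subtree)."""
    return jax.tree_util.tree_map(
        lambda x: x if isinstance(x, (jax.Array, Bounded)) else None,
        tree,
        is_leaf=_is_bounded,
    )


model = {"w": jnp.ones(2), "b": Bounded(0.5, 0.0, 1.0), "name": "layer"}
upper = tree_upper_sketch(drop_unsupported(model))  # upper["name"] stays None
```

Calling `tree_upper_sketch(model)` directly raises `TypeError` on the string leaf; filtering first keeps the tree's layout while removing the unsupported leaf.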
parax.bounds.tree_bounds(tree)
Extracts the lower and upper bounds of a potentially bounded PyTree as two separate PyTrees.
Standard (unbounded) arrays default to the interval (-inf, inf).
This function does not accept leaves that are neither arrays nor instances of parax.bounds.AbstractBounded. If your tree contains such leaves, mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `PyTree` | The PyTree model to extract bounds from. | required |
Returns:

| Type | Description |
|---|---|
| `tuple[PyTree, PyTree]` | A tuple `(lower, upper)` of PyTrees holding the lower and upper bounds. |
Source code in parax/bounds.py
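One plausible use of the `(lower, upper)` pair is projecting values back into their feasible region after an unconstrained update. The sketch below assumes that workflow; it is not taken from parax. `Bounded`, `tree_bounds_sketch` and `project` are hypothetical, and floating-point arrays are assumed. Because the lower and upper trees share one layout, a single `tree_map` over three trees can clip every leaf.

```python
import jax
import jax.numpy as jnp


class Bounded:
    """Hypothetical bounded leaf standing in for an AbstractBounded subclass."""

    def __init__(self, value, lower, upper):
        self.value = jnp.asarray(value)
        self.bounds = (jnp.asarray(lower), jnp.asarray(upper))


def _is_bounded(x):
    return isinstance(x, Bounded)


def _leaf_bounds(leaf):
    if _is_bounded(leaf):
        return leaf.bounds
    if isinstance(leaf, jax.Array):
        # Unbounded array: default interval (-inf, inf), assuming a float dtype.
        return jnp.full_like(leaf, -jnp.inf), jnp.full_like(leaf, jnp.inf)
    raise TypeError(f"unsupported leaf of type {type(leaf).__name__}")


def tree_bounds_sketch(tree):
    """Return (lower, upper): two PyTrees with one bound per leaf."""
    lower = jax.tree_util.tree_map(lambda x: _leaf_bounds(x)[0], tree, is_leaf=_is_bounded)
    upper = jax.tree_util.tree_map(lambda x: _leaf_bounds(x)[1], tree, is_leaf=_is_bounded)
    return lower, upper


def project(values, lower, upper):
    """Clip each leaf of values into [lower, upper]; all three trees share a structure."""
    return jax.tree_util.tree_map(
        lambda x, lo, hi: jnp.minimum(jnp.maximum(x, lo), hi), values, lower, upper
    )


params = {"w": Bounded(jnp.array([0.2, 0.8]), 0.0, 1.0), "b": jnp.array([3.0])}
lower, upper = tree_bounds_sketch(params)
# A hypothetical update step pushed "w" out of range; project it back.
step = {"w": jnp.array([1.5, -0.2]), "b": jnp.array([3.0])}
projected = project(step, lower, upper)  # "w" becomes [1.0, 0.0]; "b" is unchanged
```

The unbounded leaf `"b"` passes through untouched because its bounds are (-inf, inf).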