Traversal
refrax.Traversal(tree, base_path, sub_paths)
Bases: Generic[TRoot]
Represents a multi-target focus within an immutable PyTree.
Applies mutations across all targets simultaneously using eqx.tree_at,
making it fully compatible with JAX JIT compilation and tracing.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `tree` | `TRoot` | The root immutable object (e.g., Equinox module or dataclass). | required |
| `base_path` | `list[PathStep]` | The shared path from the root to the divergence point. | required |
| `sub_paths` | `list[list[PathStep]]` | A list of diverging paths, one for each targeted element. | required |
Source code in refrax/traversal.py
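As a rough illustration of how a shared `base_path` plus several `sub_paths` can address multiple targets, the sketch below resolves paths against plain frozen dataclasses. The `Step` tuple encoding, `resolve`, `Resistor`, and `Circuit` are hypothetical names made up for this example; the actual `PathStep` type and the library's internals may differ.

```python
from dataclasses import dataclass
from typing import Any

# Hypothetical step encoding for this sketch: ("attr", name) or ("item", key).
Step = tuple[str, Any]


def resolve(obj: Any, path: list[Step]) -> Any:
    """Follow a path of attribute/item steps starting from obj."""
    for kind, key in path:
        obj = getattr(obj, key) if kind == "attr" else obj[key]
    return obj


@dataclass(frozen=True)
class Resistor:
    R: float


@dataclass(frozen=True)
class Circuit:
    parts: tuple[Resistor, ...]


circuit = Circuit(parts=(Resistor(10.0), Resistor(22.0)))

# Shared prefix from the root to the point where the targets diverge.
base_path = [("attr", "parts")]
# One diverging path per targeted element.
sub_paths = [
    [("item", 0), ("attr", "R")],
    [("item", 1), ("attr", "R")],
]

# The full path of each target is base_path followed by its sub-path.
targets = [resolve(circuit, base_path + sub) for sub in sub_paths]
```

In this sketch the shared prefix is stored once and each target only carries the part of its path that differs from the others.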
__getattr__(name)
Extends the traversal by appending an attribute access to every currently focused target.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | The name of the attribute to access on all targets. | required |

Returns:

| Type | Description |
|---|---|
| `Traversal[TRoot]` | A new Traversal focused one level deeper. |
__getitem__(key)
Extends the traversal by appending an item or index access to every currently focused target.
String keys are treated as attribute accesses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `Any` | The index, dictionary key, or attribute string to access. | required |

Returns:

| Type | Description |
|---|---|
| `Traversal[TRoot]` | A new Traversal focused one level deeper. |
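The attribute and item accessors both amount to appending one step to every focused path. A minimal sketch of that idea, assuming a hypothetical `("attr", name)` / `("item", key)` step encoding and a made-up helper `append_step` (neither is taken from the library), could look like this:

```python
from typing import Any

# Hypothetical step encoding for this sketch: ("attr", name) or ("item", key).
Step = tuple[str, Any]


def append_step(sub_paths: list[list[Step]], key: Any) -> list[list[Step]]:
    """Return new sub-paths with one extra step appended to each of them."""
    # Mirrors the rule stated above: string keys become attribute accesses.
    kind = "attr" if isinstance(key, str) else "item"
    # Build new lists so the original sub-paths are left untouched.
    return [sub + [(kind, key)] for sub in sub_paths]


sub_paths = [[("item", 0)], [("item", 1)]]

deeper = append_step(sub_paths, "R")   # string -> attribute step
indexed = append_step(sub_paths, 2)    # non-string -> item step
```

Because every call returns fresh lists, the earlier, shallower set of paths stays usable after the focus has moved deeper.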
apply(func)
Applies a function to all selected targets simultaneously.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `func` | `Callable[[Any], Any]` | The transformation function to apply to each target. | required |

Returns:

| Type | Description |
|---|---|
| `TRoot` | A new instance of the root tree with all targets updated. |
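To show what "a new instance of the root tree" means in practice, here is a small sketch of a non-mutating multi-target update. It uses `dataclasses.replace` from the standard library rather than `eqx.tree_at`, and it visits the targets one after another instead of in a single call, so it illustrates the behaviour only, not the library's implementation. `update`, `apply_all`, `Resistor`, `Circuit`, and the step encoding are hypothetical.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable

# Hypothetical step encoding for this sketch: ("attr", name) or ("item", key).
Step = tuple[str, Any]


@dataclass(frozen=True)
class Resistor:
    R: float


@dataclass(frozen=True)
class Circuit:
    parts: tuple[Resistor, ...]


def update(obj: Any, path: list[Step], func: Callable[[Any], Any]) -> Any:
    """Return a copy of obj with func applied at path; obj is not modified."""
    if not path:
        return func(obj)
    (kind, key), rest = path[0], path[1:]
    if kind == "attr":
        # Rebuild the dataclass with only this one field swapped out.
        return replace(obj, **{key: update(getattr(obj, key), rest, func)})
    # "item" steps assume a tuple container in this sketch.
    items = list(obj)
    items[key] = update(obj[key], rest, func)
    return tuple(items)


def apply_all(tree: Any, paths: list[list[Step]], func: Callable[[Any], Any]) -> Any:
    """Apply func at every path, threading the updated tree through."""
    for path in paths:
        tree = update(tree, path, func)
    return tree


circuit = Circuit(parts=(Resistor(10.0), Resistor(22.0)))
paths = [
    [("attr", "parts"), ("item", 0), ("attr", "R")],
    [("attr", "parts"), ("item", 1), ("attr", "R")],
]

doubled = apply_all(circuit, paths, lambda r: r * 2)
```

The sequential fold gives the same result as a simultaneous update only when the target paths do not overlap, which is the case here.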
exclude(predicate)
Filters the currently focused targets, dropping those that match the condition.
This is the logical inverse of .filter().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `predicate` | `Callable[[Any], bool]` | A function that returns True to drop a target. | required |

Returns:

| Type | Description |
|---|---|
| `Traversal[TRoot]` | A new Traversal with the matching targets removed. |
filter(predicate)
Filters the currently focused targets, keeping only those that match the condition.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `predicate` | `Callable[[Any], bool]` | A function that returns True to keep a target, or False to drop it. | required |

Returns:

| Type | Description |
|---|---|
| `Traversal[TRoot]` | A new Traversal focused only on the targets that passed the filter. |
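The relationship between `filter` and `exclude` can be sketched on a plain list of focused values. The helpers `filter_values` and `exclude_values` below are hypothetical stand-ins written for this example; they operate on values only and do not track paths the way a Traversal would.

```python
from typing import Any, Callable


def filter_values(values: list[Any], predicate: Callable[[Any], bool]) -> list[Any]:
    """Keep the values for which predicate returns True."""
    return [v for v in values if predicate(v)]


def exclude_values(values: list[Any], predicate: Callable[[Any], bool]) -> list[Any]:
    """Drop the values for which predicate returns True."""
    # Excluding is filtering with the predicate negated.
    return filter_values(values, lambda v: not predicate(v))


def is_positive(v: float) -> bool:
    return v > 0


values = [1.0, -2.0, 3.0]

kept = filter_values(values, is_positive)
remaining = exclude_values(values, is_positive)
```

With the same predicate, the two results partition the original targets: every value ends up in exactly one of them.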
get()
Extracts all focused values.
Returns:
| Type | Description |
|---|---|
| `list[Any]` | A list containing the values of all currently focused targets. |
path(target_path)
Extends the traversal by appending a string path or JAX KeyPath to every currently focused target.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `target_path` | `str \| tuple` | The path string (e.g., '.res.R') or native JAX KeyPath tuple. | required |

Returns:

| Type | Description |
|---|---|
| `Traversal[TRoot]` | A new Traversal focused one level deeper across all branches. |
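A string such as `'.res.R'` has to be broken into steps before it can be appended to each focused path. The sketch below assumes the simplest possible grammar, dot-separated attribute names only, and uses hypothetical helpers `parse_dotted` and `extend_all`; the path syntax the library actually accepts may be richer (indices, keys, JAX KeyPath entries).

```python
def parse_dotted(path: str) -> list[str]:
    """Split a dotted path like '.res.R' into attribute names."""
    # A leading dot produces an empty first segment, which is skipped.
    return [name for name in path.split(".") if name]


def extend_all(sub_paths: list[list[str]], target_path: str) -> list[list[str]]:
    """Append the parsed steps to every focused path, returning new lists."""
    steps = parse_dotted(target_path)
    return [sub + steps for sub in sub_paths]


steps = parse_dotted(".res.R")
extended = extend_all([["a"], ["b"]], ".res.R")
```

Every branch receives the same suffix, so the number of focused targets is unchanged while each path grows by the number of parsed steps.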
prune(*paths)
Filters out targets from the Traversal that intersect with the given paths.
A target is pruned if it is exactly an excluded path, lies inside an excluded path, or is a parent of an excluded path. Parents must be pruned as well, because mutating a parent implicitly mutates its children.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*paths` | `str \| tuple` | String paths or JAX KeyPath tuples to protect from mutation. | `()` |

Returns:

| Type | Description |
|---|---|
| `Traversal[TRoot]` | A new Traversal with the intersecting paths removed. |
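The three pruning cases (same path, inside the excluded path, parent of the excluded path) reduce to one check: one path is a prefix of the other. A minimal sketch of that rule, with paths modelled as tuples of names and hypothetical helpers `intersects` and `prune_paths`, might be:

```python
Path = tuple[str, ...]


def intersects(target: Path, excluded: Path) -> bool:
    """True if either path is a prefix of the other (or they are equal)."""
    n = min(len(target), len(excluded))
    return target[:n] == excluded[:n]


def prune_paths(targets: list[Path], *excluded: Path) -> list[Path]:
    """Drop every target that intersects any excluded path."""
    return [t for t in targets if not any(intersects(t, e) for e in excluded)]


targets = [
    ("res",),               # parent of the excluded path
    ("res", "R"),           # exactly the excluded path
    ("res", "R", "value"),  # inside the excluded path
    ("cap", "C"),           # unrelated
]

safe = prune_paths(targets, ("res", "R"))
```

Only the unrelated target survives. Keeping the parent `("res",)` would let a later update replace the whole `res` subtree, including the protected `R`.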
set(value)
Sets all selected targets to a specific value, or maps a sequence of values 1-to-1.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `value` | `Any` | The new value to assign. If value is a list/tuple of the exact same length as the focused targets, they will be mapped 1-to-1. | required |

Returns:

| Type | Description |
|---|---|
| `TRoot` | A new instance of the root tree with the updated values. |
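The broadcast-or-map rule described for `value` can be sketched as a small helper that decides which value each target receives. `expand_values` is a hypothetical function written for this example and only mirrors the rule as stated in the parameter description.

```python
from typing import Any


def expand_values(value: Any, n_targets: int) -> list[Any]:
    """Return one value per target, following the rule described above."""
    # A list/tuple whose length matches the target count is mapped 1-to-1.
    if isinstance(value, (list, tuple)) and len(value) == n_targets:
        return list(value)
    # Anything else is assigned as-is to every target.
    return [value] * n_targets


broadcast = expand_values(0.0, 3)
mapped = expand_values([1.0, 2.0, 3.0], 3)
mismatched = expand_values((1.0, 2.0), 3)
```

Under this rule a sequence of a different length is treated as a single value and assigned whole to each target, as `mismatched` shows. Conversely, a sequence that happens to match the target count is always spread across the targets, which matters when the intended value is itself a list or tuple.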