Skip to content

Lens

refrax.focus(tree)

Returns a Lens focused on the root of `tree` (i.e. with an empty path).

Source code in refrax/lens.py
225
226
227
228
229
def focus(tree: TRoot) -> Lens[TRoot]:
    """Creates a Lens rooted at ``tree``.

    Args:
        tree (TRoot): The root object the Lens will operate on.

    Returns:
        Lens[TRoot]: A Lens over ``tree`` with no path steps yet, i.e.
            focused on the root itself.
    """
    # No path is supplied, so the Lens starts out focused on the root.
    return Lens(tree, path=None)

refrax.Lens(tree, path=None)

Bases: Generic[TRoot]

A fluent interface for functionally updating immutable Equinox PyTrees (updates return a new tree rather than modifying the original in place).

Uses eqx.tree_at under the hood to functionally swap leaves in the PyTree, making it completely safe for use inside JAX JIT/vmap boundaries.

Parameters:

Name Type Description Default
tree TRoot

The root immutable object to be mutated.

required
path list[PathStep] | None

The current traversal path from the root, by default None.

None
Source code in refrax/lens.py
24
25
26
def __init__(self, tree: TRoot, path: list[PathStep] | None = None) -> None:
    """Stores the root tree and the traversal path of this Lens.

    Args:
        tree (TRoot): The root object the Lens operates on.
        path (list[PathStep] | None): Steps from the root to the focus.
            ``None`` means "no steps", i.e. the Lens is focused on the root.
    """
    self._tree = tree
    # A fresh list is created only when no path was given; a supplied list
    # is stored as-is (not copied).
    self._path = [] if path is None else path

__getattr__(name)

Focuses the lens on a named attribute.

Parameters:

Name Type Description Default
name str

The name of the attribute.

required

Returns:

Type Description
Lens[TRoot]

Lens[TRoot]: A new Lens focused on the specified attribute.

Source code in refrax/lens.py
28
29
30
31
32
33
34
35
36
37
38
39
def __getattr__(self, name: str) -> "Lens[TRoot]":
    """Extends the focus by one attribute step.

    Args:
        name (str): Attribute name to focus on.

    Returns:
        Lens[TRoot]: A new Lens over the same root whose path is the current
            path followed by an ``("attr", name)`` step.

    Raises:
        AttributeError: If ``name`` both starts and ends with a double
            underscore.
    """
    # Dunder-style names never become path steps; they are reported as
    # missing attributes instead.
    # NOTE(review): presumably this keeps protocol lookups (copy, pickle,
    # hasattr probes, ...) from receiving a Lens -- confirm intent.
    is_dunder = name.startswith('__') and name.endswith('__')
    if is_dunder:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    root = self._tree
    extended_path = self._path + [("attr", name)]
    return Lens(root, extended_path)

__getitem__(key)

Focuses the lens on a collection item, dictionary key, or attribute.

If the key is a string, it is treated as an attribute access (equivalent to getattr). Otherwise, it is treated as an item access.

Parameters:

Name Type Description Default
key Any

The index, key, or attribute name.

required

Returns:

Type Description
Lens[TRoot]

Lens[TRoot]: A new Lens focused on the specified item or attribute.

Source code in refrax/lens.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def __getitem__(self, key: Any) -> "Lens[TRoot]":
    """Extends the focus by one step selected with subscript syntax.

    A ``str`` key is recorded as an attribute step (the same kind of step
    that attribute access on the Lens produces); any other key is recorded
    as an item step.

    Args:
        key (Any): The index, key, or attribute name.

    Returns:
        Lens[TRoot]: A new Lens over the same root with the extra step
            appended to the current path.
    """
    if isinstance(key, str):
        step = ("attr", key)
    else:
        step = ("item", key)

    return Lens(self._tree, self._path + [step])

apply(func)

Applies a transformation function to the focused target.

Parameters:

Name Type Description Default
func Callable[[Any], Any]

The function to transform the current value.

required

Returns:

Name Type Description
TRoot TRoot

A new instance of the root tree with the updated value.

Source code in refrax/lens.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def apply(self, func: Callable[[Any], Any]) -> TRoot:
    """Replaces the focused value with the result of ``func`` applied to it.

    Args:
        func (Callable[[Any], Any]): Receives the currently focused value
            and returns its replacement.

    Returns:
        TRoot: The root tree with the focused value replaced. When the path
            is empty this is simply ``func`` applied to the root.
    """
    if self._path:
        # Non-empty path: let eqx.tree_at perform the functional swap.
        return eqx.tree_at(self._get_target_from, self._tree, replace_fn=func)

    # Empty path: the root itself is the focus, so transform it directly.
    return cast(TRoot, func(self._tree))

each()

Transforms a focus on a collection into a Traversal of its elements.

Returns:

Type Description
Traversal[TRoot]

Traversal[TRoot]: A Traversal object focused on every item in the target collection.

Raises:

Type Description
TypeError

If the current focus is not a dictionary, list, or tuple.

Examples:

>>> new_model = focus(model).sources.each().apply(lambda x: x * 2)
Source code in refrax/lens.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def each(self) -> Traversal[TRoot]:
    """Branches the focus into one sub-focus per element of a collection.

    For a ``dict`` there is one branch per key; for a ``list`` or ``tuple``
    there is one branch per index. Every branch is a single ``"item"`` step
    relative to the current path.

    Returns:
        Traversal[TRoot]: A Traversal over the same root and current path,
            with one sub-path per element of the focused collection.

    Raises:
        TypeError: If the current focus is not a dictionary, list, or tuple.

    Examples:
        >>> new_model = focus(model).sources.each().apply(lambda x: x * 2)
    """
    focused = self.get()

    # Work out which keys/indices address the elements of the collection.
    if isinstance(focused, dict):
        selectors = list(focused.keys())
    elif isinstance(focused, (list, tuple)):
        selectors = list(range(len(focused)))
    else:
        raise TypeError(f"Cannot iterate over {type(focused).__name__} with .each()")

    element_paths: list[list[PathStep]] = [[("item", selector)] for selector in selectors]
    return Traversal(self._tree, self._path, element_paths)

get()

Extracts the currently focused value.

Returns:

Name Type Description
Any Any

The value at the end of the Lens path.

Source code in refrax/lens.py
75
76
77
78
79
80
81
def get(self) -> Any:
    """Reads the value the Lens is currently focused on.

    Returns:
        Any: The result of resolving this Lens against its root tree
            via ``_get_target_from``.
    """
    root = self._tree
    return self._get_target_from(root)

leaves(is_leaf=None)

Instantly branches the Lens into a Traversal of all leaf nodes (arrays/scalars).

Powered by JAX's C++ backend (tree_leaves_with_path), making it lightning fast.

Source code in refrax/lens.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def leaves(self, is_leaf: Callable[[Any], bool] | None = None) -> Traversal[TRoot]:
    """Branches the focus into one sub-focus per leaf of the focused value.

    Leaves and their key paths are enumerated with
    ``jtu.tree_leaves_with_path``; each key path is converted to Lens path
    steps with ``translate_jax_path``.

    Args:
        is_leaf (Callable[[Any], bool] | None): Forwarded unchanged as the
            ``is_leaf`` argument of ``jtu.tree_leaves_with_path``.

    Returns:
        Traversal[TRoot]: A Traversal over the same root and current path,
            with one sub-path per enumerated leaf.
    """
    focused = self.get()

    # Only the key paths are needed; the leaf values themselves are ignored.
    leaf_paths = [
        translate_jax_path(key_path)
        for key_path, _leaf in jtu.tree_leaves_with_path(focused, is_leaf=is_leaf)
    ]

    return Traversal(self._tree, self._path, leaf_paths)

path(target_path)

Advances the Lens focus based on a string path or JAX KeyPath tuple.

Parameters:

Name Type Description Default
target_path str | tuple

The path string (e.g., '.res.R') or native JAX KeyPath tuple.

required

Returns:

Type Description
Lens[TRoot]

Lens[TRoot]: A new Lens focused on the parsed path.

Source code in refrax/lens.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def path(self, target_path: str | tuple) -> "Lens[TRoot]":
    """Advances the Lens focus based on a string path or JAX KeyPath tuple.

    Args:
        target_path (str | tuple): The path string (e.g., '.res.R') or native 
            JAX KeyPath tuple.

    Returns:
        Lens[TRoot]: A new Lens focused on the parsed path.
    """
    if isinstance(target_path, str):
        parsed_steps = parse_string_path(target_path)
    elif isinstance(target_path, tuple):
        parsed_steps = translate_jax_path(target_path)
    else:
        raise TypeError(f"Expected string or JAX path tuple, got {type(target_path).__name__}")

    return Lens(self._tree, self._path + parsed_steps)    

select(*paths)

Branches the Lens into a Traversal targeting multiple attributes.

Supports JAX-style string paths (e.g., '.a[0].b') OR native JAX KeyPath tuples returned by jax.tree_util.

Parameters:

Name Type Description Default
*paths str | tuple

A variable number of attribute paths to target.

()

Returns:

Type Description
Traversal[TRoot]

Traversal[TRoot]: A Traversal object focused on the specified attributes.

Source code in refrax/lens.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def select(self, *paths: str | tuple) -> Traversal[TRoot]:
    """Branches the focus into one sub-focus per given path.

    Each path may be a JAX-style string (e.g., '.a[0].b'), parsed with
    ``parse_string_path``, or a JAX KeyPath tuple, converted with
    ``translate_jax_path``.

    Args:
        *paths (str | tuple): Paths, relative to the current focus, to target.

    Returns:
        Traversal[TRoot]: A Traversal over the same root and current path,
            with one sub-path per argument, in the order given.

    Raises:
        TypeError: If any path is neither a ``str`` nor a ``tuple``.
    """

    def _as_steps(candidate: Any) -> list[PathStep]:
        # Convert one user-supplied path into Lens path steps.
        if isinstance(candidate, str):
            return parse_string_path(candidate)
        if isinstance(candidate, tuple):
            return translate_jax_path(candidate)
        raise TypeError(f"Expected string or JAX path tuple, got {type(candidate).__name__}")

    # Paths are converted left to right, so the first invalid one raises.
    branches: list[list[PathStep]] = [_as_steps(candidate) for candidate in paths]

    return Traversal(self._tree, self._path, branches)

set(value)

Sets the focused target to a specific value.

Parameters:

Name Type Description Default
value Any

The new value to assign.

required

Returns:

Name Type Description
TRoot TRoot

A new instance of the root tree with the updated value.

Source code in refrax/lens.py
83
84
85
86
87
88
89
90
91
92
93
94
95
def set(self, value: Any) -> TRoot:
    """Replaces the focused value with ``value``.

    Args:
        value (Any): The replacement for the currently focused value.

    Returns:
        TRoot: The root tree with the focused value replaced. When the path
            is empty, ``value`` itself is returned as the new root.
    """
    if self._path:
        # Non-empty path: let eqx.tree_at perform the functional swap.
        return eqx.tree_at(self._get_target_from, self._tree, replace=value)

    # Empty path: the focus is the root, so the new value is the new root.
    return cast(TRoot, value)

where(predicate)

Traverses immediate items/attributes of the current focus that match a condition.

Parameters:

Name Type Description Default
predicate Callable[[Any], bool]

Returns True if the immediate child should be included in the Traversal.

required

Returns:

Type Description
Traversal[TRoot]

Traversal[TRoot]: A Traversal focused on matching immediate children.

Source code in refrax/lens.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
def where(self, predicate: Callable[[Any], bool]) -> Traversal[TRoot]:
    """Branches the focus into the immediate children that satisfy ``predicate``.

    Children are enumerated as follows:

    * ``dict``: its values, addressed by key (item steps).
    * ``list`` / ``tuple``: its elements, addressed by index (item steps).
    * any other object with a ``__dict__``: the entries of ``vars(obj)``
      whose name does not start with an underscore (attribute steps).

    A focus of any other kind yields a Traversal with no sub-paths.

    Args:
        predicate (Callable[[Any], bool]): Called with each immediate child
            value; a truthy result includes that child in the Traversal.

    Returns:
        Traversal[TRoot]: A Traversal over the same root and current path,
            with one single-step sub-path per matching child.
    """
    focused = self.get()

    # Lazily pair each immediate child with the step that addresses it, so
    # the predicate is still evaluated child by child, in iteration order.
    if isinstance(focused, dict):
        candidates = ((("item", key), child) for key, child in focused.items())
    elif isinstance(focused, (list, tuple)):
        candidates = ((("item", index), child) for index, child in enumerate(focused))
    elif hasattr(focused, '__dict__'):
        candidates = (
            (("attr", attr_name), child)
            for attr_name, child in vars(focused).items()
            if not attr_name.startswith('_')
        )
    else:
        candidates = iter(())

    matching: list[list[PathStep]] = [[step] for step, child in candidates if predicate(child)]

    return Traversal(self._tree, self._path, matching)