Refrax
Refrax is a small, pragmatic library implementing the optics pattern for JAX PyTrees.
It focuses on elegant syntax and composable conditional chains to provide easy yet powerful PyTree manipulation.
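To make the optics idea more concrete, here is a minimal sketch of a lens over nested Python dicts. A lens is a composable getter/setter pair. The `Lens` class and `key` helper are hypothetical illustrations of the general pattern. They are not refrax's API or its implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Lens:
    """Hypothetical minimal lens: a getter plus a functional (copying) setter."""
    getter: Callable[[Any], Any]
    setter: Callable[[Any, Any], Any]  # (whole, new_part) -> new whole

    def get(self, whole):
        return self.getter(whole)

    def set(self, whole, value):
        return self.setter(whole, value)

    def apply(self, whole, fn):
        # "Modify" is a get followed by a set; the input is never mutated.
        return self.set(whole, fn(self.get(whole)))

    def compose(self, other: "Lens") -> "Lens":
        # Focus through self first, then through other.
        return Lens(
            getter=lambda w: other.get(self.get(w)),
            setter=lambda w, v: self.set(w, other.set(self.get(w), v)),
        )


def key(k):
    """Hypothetical helper: a lens onto one dict key that returns a copied dict on update."""
    return Lens(getter=lambda d: d[k], setter=lambda d, v: {**d, k: v})


params = {"core": {"bias": 0.5, "weight": 2.0}, "dropout": 0.5}
core_bias = key("core").compose(key("bias"))
new_params = core_bias.apply(params, lambda b: b + 1.0)
print(new_params["core"]["bias"])  # 1.5
print(params["core"]["bias"])      # 0.5 (the original is unchanged)
```

Composition is the key property. Two small lenses combine into one that reaches deeper into the structure, and every update returns a fresh copy instead of mutating the input.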
Installation
Refrax can be installed using pip:
pip install refrax
Quick example
import equinox as eqx
import jax
class Model(eqx.Module):
    core: eqx.nn.Linear
    head: eqx.nn.Linear
    dropout: float
key1, key2 = jax.random.split(jax.random.key(0))
model = Model(
core=eqx.nn.Linear(in_features=5, out_features=5, key=key1),
head=eqx.nn.Linear(in_features=5, out_features=2, key=key2),
dropout=0.5
)
We can then update the model using focus:
from refrax import focus
model = focus(model).dropout.set(0.1)
model = focus(model).select("core", "head").bias.apply(lambda b: b + 1.0)
model.dropout
# 0.1
model.core.bias
# [0.646, 0.860, 0.670, 1.277, 0.727]
model.head.bias
# [1.634, 1.877]
Another useful optic simply traverses all of a PyTree's leaves:
len(focus(model).leaves().get())
# 5
Documentation
Documentation is available here.
Related
The library uses Equinox (specifically eqx.tree_at) under the hood.