1. Defining the model
Let's define a simple exponential decay model, \(y = A e^{-kt} + C\), using equinox and parax!
import jax.numpy as jnp
import equinox as eqx
import parax as prx
from parax.constraints import Positive
class DecayCurve(eqx.Module):
    rate: prx.Param
    baseline: prx.Param
    amplitude: prx.Param = prx.constrained(Positive())

    def __call__(self, t):
        return self.amplitude * jnp.exp(-self.rate * t) + self.baseline
model = DecayCurve(
    amplitude=5.0,
    rate=prx.Tagged(0.5, metadata={'desc': 'Decay constant'}),
    baseline=1.2,
)
Note that the bare float 5.0 passed for amplitude is automatically converted to a prx.Constrained by the dataclass field.
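A quick way to see the conversion (only type inspection is used here; the wrapper type names follow parax's API as shown above):

# The field converter wraps the bare float; explicit wrappers pass through untouched.
print(type(model.amplitude))  # expected: a prx.Constrained instance, not float
print(type(model.rate))       # prx.Tagged, exactly as passed in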
For demonstration purposes, we fix the baseline so that it is excluded from optimization:
import dataclasses
model = dataclasses.replace(model, baseline=prx.Fixed(model.baseline))
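Equivalently, equinox's own tree-surgery helper performs the same replacement; this is standard eqx.tree_at usage, assuming baseline is a leaf of the model PyTree:

# Same replacement via equinox's tree surgery instead of dataclasses.replace:
model = eqx.tree_at(lambda m: m.baseline, model, replace=prx.Fixed(model.baseline))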
2. Setting up the loss
Optimization libraries like optimistix expect standard JAX arrays, so we need to split our model into trainable parameters and static metadata, then re-combine the two halves during the forward pass.
By passing is_leaf=prx.is_constant to eqx.partition, we can also separate out all prx.Fixed variables (and nested prx.Frozen models) into the static half of the tree.
import jax
params, static = eqx.partition(model, eqx.is_inexact_array, is_leaf=prx.is_constant)
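As a sanity check of the split (the exact tree layout depends on parax's wrappers, so the output shown is an assumption):

# Trainable half: only the inexact-array leaves remain.
print(jax.tree_util.tree_leaves(params))
# Static half: the Fixed baseline ends up here, alongside the metadata.
print(static.baseline)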
def loss_fn(params, args):
    t, y_true = args
    current_model = prx.unwrap(eqx.combine(params, static))
    y_pred = jax.vmap(current_model)(t)
    return jnp.mean((y_pred - y_true) ** 2)
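Before handing this to a solver, a one-off evaluation on throwaway data confirms the function returns a scalar (the data here is purely illustrative):

t_demo = jnp.linspace(0.0, 5.0, 10)
y_demo = jnp.ones_like(t_demo)
print(loss_fn(params, (t_demo, y_demo)))  # a scalar MSE value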
prx.unwrap() recursively resolves any derived variables and wrapped PyTrees from the bottom up.
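Applied to our model directly, it strips the wrappers and returns a module holding plain values (a minimal sketch; the expected output assumes the wrappers resolve to the values used at construction):

plain = prx.unwrap(model)
print(plain.baseline)  # expected: 1.2, the Fixed value resolved to a plain number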
3. Running the optimizer
Finally, we generate some dummy data with amplitude=2.0 and rate=1.0 (and the fixed baseline of 1.2), and let optimistix recover the underlying parameters.
import optimistix as optx
t_data = jnp.linspace(0, 5, 100)
y_data = 2.0 * jnp.exp(-1.0 * t_data) + 1.2
solver = optx.BFGS(rtol=1e-5, atol=1e-5)
results = optx.minimise(loss_fn, solver, y0=params, args=(t_data, y_data))
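It is worth confirming that the solve converged before reading off any values; optx.RESULTS.successful is standard optimistix API:

# Raise early if the solver reports anything other than success.
assert results.result == optx.RESULTS.successful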
final_model = prx.unwrap(eqx.combine(results.value, static))
Our optimized model recovers the parameters used to generate the data:
final_model.amplitude
# 2.000002
final_model.rate
# 1.0000012
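Since final_model is a plain equinox module after unwrapping, we can also evaluate it directly, e.g. to check the residual against the training data:

# Maximum absolute deviation between the fitted curve and the data.
y_fit = jax.vmap(final_model)(t_data)
print(jnp.max(jnp.abs(y_fit - y_data)))  # close to zero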