Overview
In this example, we define a constrained exponential decay model with equinox and fit it with optimistix.
1. Defining the model
First, we define a simple exponential decay model \(y = A e^{-kt} + C\) using eqx.Module:
import jax.numpy as jnp
import equinox as eqx
import parax as prx
from parax.constraints import Positive
class DecayCurve(eqx.Module):
    rate: prx.Param
    baseline: prx.Param = eqx.field(converter=jnp.asarray)
    amplitude: prx.Param = eqx.field(converter=lambda x: prx.Constrained(Positive(), x))

    def __call__(self, t):
        return self.amplitude * jnp.exp(-self.rate * t) + self.baseline

model = DecayCurve(
    amplitude=5.0,
    rate=prx.Tagged(0.5, metadata={'desc': 'Decay constant'}),
    baseline=1.2,
)
prx.Param is simply a type hint for JAX arrays or built-in Parax variables. Notice also how we can specify converters using equinox.field to enforce constraints or automatically convert fields to arrays. Since all constraints in parax.constraints are bijective, prx.Constrained accepts the constrained value at construction.
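To make the role of a bijective constraint concrete, here is a minimal sketch in plain JAX. It assumes an exp/log pair as the bijection; this is only one common choice, and Parax's Positive constraint may use a different mapping internally. The function names to_unconstrained and to_positive are hypothetical and not part of Parax.

```python
import jax.numpy as jnp

# Hypothetical helpers, NOT the Parax API: an exp/log bijection between
# the positive reals and the whole real line.
def to_unconstrained(x):
    """Map a strictly positive value to an unconstrained real number."""
    return jnp.log(x)

def to_positive(z):
    """Map any real number back to a strictly positive value."""
    return jnp.exp(z)

amplitude = jnp.asarray(5.0)        # the constrained value a user supplies
raw = to_unconstrained(amplitude)   # the value an optimizer could update freely
recovered = to_positive(raw)        # round trip returns the original value
stepped = to_positive(raw - 10.0)   # even a large step stays positive
```

Because the mapping is invertible, a user can supply the constrained value (here 5.0) and the unconstrained representation can be derived from it, which is the property the paragraph above relies on.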
For demonstration purposes, we fix the baseline in this example:
import dataclasses
model = dataclasses.replace(model, baseline=prx.Fixed(model.baseline))
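The call above uses dataclasses.replace rather than attribute assignment because equinox modules are frozen dataclasses. The following sketch shows the same pattern with a plain standard-library dataclass; the Curve class is a hypothetical stand-in, not the DecayCurve model.

```python
import dataclasses

# Hypothetical stand-in for an immutable model.
@dataclasses.dataclass(frozen=True)
class Curve:
    rate: float
    baseline: float

original = Curve(rate=0.5, baseline=1.2)

# replace() builds a new instance; the original is left untouched.
updated = dataclasses.replace(original, baseline=0.0)

# Direct assignment on a frozen dataclass raises FrozenInstanceError.
try:
    original.baseline = 0.0
    assignment_failed = False
except dataclasses.FrozenInstanceError:
    assignment_failed = True
```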
2. Setting up the loss
Optimization libraries like optimistix expect standard JAX arrays. We therefore split our model into trainable parameters and static metadata, and then recombine the two halves inside the forward pass.
By passing is_leaf=prx.is_constant to eqx.partition, we can also separate out all prx.Fixed variables (and nested prx.Freeze models) into the static half of the tree.
import jax
params, static = eqx.partition(model, eqx.is_inexact_array, is_leaf=prx.is_constant)
def loss_fn(params, args):
    t, y_true = args
    current_model = prx.unwrap(eqx.combine(params, static))
    y_pred = jax.vmap(current_model)(t)
    return jnp.mean((y_pred - y_true)**2)
parax.unwrap() recursively resolves any derived variables/PyTrees from the bottom up.
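The partition and combine steps can be illustrated in plain JAX. This is a conceptual sketch only, not the equinox or Parax implementation: the Fixed class below is a hypothetical marker, and the model is a plain dict. As with eqx.partition, each half keeps the full tree structure and uses None as a placeholder for entries that live in the other half.

```python
import jax
import jax.numpy as jnp

class Fixed:
    """Hypothetical marker for a value that must not be trained."""
    def __init__(self, value):
        self.value = value

def is_fixed(x):
    return isinstance(x, Fixed)

model = {
    "amplitude": jnp.asarray(5.0),
    "rate": jnp.asarray(0.5),
    "baseline": Fixed(jnp.asarray(1.2)),
}

# Trainable half: keep arrays, put None where a fixed value sits.
params = jax.tree_util.tree_map(
    lambda x: None if is_fixed(x) else x, model, is_leaf=is_fixed)
# Static half: keep fixed values, put None everywhere else.
static = jax.tree_util.tree_map(
    lambda x: x if is_fixed(x) else None, model, is_leaf=is_fixed)

def combine(a, b):
    """Merge two halves, taking whichever entry is not None."""
    return jax.tree_util.tree_map(
        lambda x, y: y if x is None else x, a, b,
        is_leaf=lambda x: x is None)

def loss(params):
    full = combine(params, static)
    return (full["amplitude"] * full["rate"] + full["baseline"].value) ** 2

# Gradients are computed only for the trainable half.
grads = jax.grad(loss)(params)
```

Since None is an empty PyTree node, jax.grad sees only the amplitude and rate arrays, and the fixed baseline receives no gradient.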
3. Running the optimizer
Finally, we generate some dummy data with amplitude=2.0 and rate=1.0 and let optimistix find the underlying parameters.
import optimistix as optx
t_data = jnp.linspace(0, 5, 100)
y_data = 2.0 * jnp.exp(-1.0 * t_data) + 1.2
solver = optx.BFGS(rtol=1e-5, atol=1e-5)
results = optx.minimise(loss_fn, solver, y0=params, args=(t_data, y_data))
final_model = prx.unwrap(eqx.combine(results.value, static))
Our optimized model recovers the parameters used to generate the data:
final_model.amplitude
# 2.000002
final_model.rate
# 1.0000012
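As a sanity check that does not depend on Parax or optimistix, the same fit can be sketched in plain JAX. Note the substitutions: this sketch uses fixed-step gradient descent instead of BFGS, and it enforces a positive amplitude by optimizing its logarithm, which is an assumed reparameterization rather than what Parax necessarily does. The step size and iteration count are illustrative choices.

```python
import jax
import jax.numpy as jnp

# Same synthetic data as above: amplitude=2.0, rate=1.0, baseline=1.2.
t = jnp.linspace(0, 5, 100)
y = 2.0 * jnp.exp(-1.0 * t) + 1.2

BASELINE = 1.2        # held fixed, as in the example above
LEARNING_RATE = 0.01  # illustrative value
NUM_STEPS = 5000      # illustrative value

def loss(theta):
    # theta[0] is log(amplitude), so the amplitude is always positive.
    amplitude, rate = jnp.exp(theta[0]), theta[1]
    pred = amplitude * jnp.exp(-rate * t) + BASELINE
    return jnp.mean((pred - y) ** 2)

def step(_, theta):
    # One plain gradient-descent update.
    return theta - LEARNING_RATE * jax.grad(loss)(theta)

# Start from the same initial guess as the model: amplitude=5.0, rate=0.5.
theta0 = jnp.array([jnp.log(5.0), 0.5])
theta = jax.lax.fori_loop(0, NUM_STEPS, step, theta0)

fitted_amplitude = jnp.exp(theta[0])
fitted_rate = theta[1]
```

Gradient descent needs far more iterations than a quasi-Newton method such as BFGS on this problem, which is one reason to prefer a dedicated solver.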