from typing import Callable
import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
import parax as prx
from distreqx.bijectors import AbstractBijector
from pmrf.core import Model, Frequency, Problem
from pmrf.optimize.result import OptimizeResult
from pmrf.optimize.solvers import ScipyMinimizer
# [docs]  (Sphinx source-link artifact from the rendered documentation; not code)
def is_optimizer(x):
    """
    Return whether a solver is suitable for frequentist optimization in :mod:`pmrf.optimize`.

    Returns ``True`` for ``optimistix.AbstractMinimiser`` and
    :class:`pmrf.optimize.ScipyMinimizer` instances, ``False`` otherwise.
    """
    # Tuple form of isinstance is equivalent to the X | Y union and also
    # works on Python versions before 3.10.
    return isinstance(x, (optx.AbstractMinimiser, ScipyMinimizer))
# [docs]  (Sphinx source-link artifact from the rendered documentation; not code)
def minimize(
    cost_fn: Callable[[Model, Frequency], jnp.ndarray] | list[Callable],
    model: Model,
    frequency: Frequency,
    solver: optx.AbstractMinimiser | Callable = ScipyMinimizer(),
    *,
    transform: AbstractBijector | Callable[[prx.Parameter], AbstractBijector] | None = None,
    max_steps: int = 10000,
    **kwargs,
) -> OptimizeResult:
    """
    Minimizes a given cost function for a model over a frequency range.

    The cost function can have its own hyper-parameters, and is returned in ``result.cost``.

    Parameters
    ----------
    cost_fn : Callable[[Model, Frequency], jnp.ndarray] | list[Callable]
        The cost function to minimize. Must be a callable or PyTree with signature
        ``(model, freq) -> jnp.ndarray``. If a list (or tuple) of costs is provided,
        they are automatically summed.
        See :class:`pmrf.evaluators.Goal` for an easy way to define goal-based cost functions.
    model : Model
        The RF model containing the parameters to be optimized.
    frequency : Frequency
        The frequency sweep over which the cost should be evaluated.
    solver : optx.AbstractMinimiser | Callable, default=ScipyMinimizer()
        The optimization backend to use. Defaults to the host-based SciPy L-BFGS-B.
    transform : AbstractBijector | Callable[[prx.Parameter], AbstractBijector] | None, default=None
        An invertible transformation applied to all model parameters before
        optimization. May also be a callable mapping each parameter to its own bijector.
    max_steps : int, default=10000
        The maximum number of steps/iterations the underlying solver can take.
    **kwargs : dict
        Additional options passed to the underlying solver backend.

    Returns
    -------
    OptimizeResult
        A structured result containing the fitted model and solver statistics.

    Raises
    ------
    ValueError
        If the model has no free parameters to fit, or if ``has_aux`` is
        requested with a host (SciPy) solver.
    """
    # Hoisted out of apply_inverse so the import statement is not re-executed
    # for every leaf visited by jax.tree.map.
    from distreqx.bijectors import Inverse

    # Normalize the cost into a single eqx.Module; a list/tuple of costs is summed.
    if isinstance(cost_fn, (list, tuple)):
        cost_fn = prx.op.Sum([c if isinstance(c, eqx.Module) else prx.op.Lambda(c) for c in cost_fn])
    else:
        cost_fn = cost_fn if isinstance(cost_fn, eqx.Module) else prx.op.Lambda(cost_fn)
    problem = Problem(model=model, frequency=frequency, evaluator=cost_fn)
    if problem.num_flat_params == 0:
        raise ValueError("Model has no free parameters to fit")

    # Helper to dynamically resolve the bijector for a given parameter.
    # `transform` may be a single bijector shared by all parameters, or a
    # callable producing a per-parameter bijector.
    def get_bijector(p: prx.Parameter) -> AbstractBijector:
        return transform(p) if callable(transform) else transform

    # Maps a transformed-space leaf back to physical space, using the original
    # (untransformed) leaf to resolve which bijector to invert.
    def apply_inverse(orig_x, trans_x):
        if isinstance(orig_x, prx.Parameter):
            inv_bij = Inverse(get_bijector(orig_x))
            return trans_x.transformed(inv_bij)
        return trans_x

    # 1. Apply the parameter space transform dynamically
    if transform is not None:
        transformed_problem = jax.tree.map(
            lambda x: x.transformed(get_bijector(x)) if isinstance(x, prx.Parameter) else x,
            problem,
            is_leaf=prx.is_free_param
        )
    else:
        transformed_problem = problem
    # Split into trainable leaves (seen by the solver) and static structure.
    transformed_params, transformed_static = prx.partition(transformed_problem)

    def obj_fn(transformed_params, _args):
        full_transformed = eqx.combine(transformed_params, transformed_static)
        # Map over both the source of truth (problem) and the solver state simultaneously
        if transform is not None:
            full_physical = jax.tree.map(
                apply_inverse,
                problem,
                full_transformed,
                is_leaf=prx.is_free_param
            )
        else:
            full_physical = full_transformed
        return full_physical()

    # 2. Routing logic for bounding and solver execution
    if isinstance(solver, ScipyMinimizer):
        if 'bounds' in kwargs:
            lower_tree, upper_tree = kwargs.pop('bounds')
        else:
            # Derive per-parameter bounds: explicit bounds win; otherwise fall
            # back to the 1%/99% quantiles of the prior distribution (when an
            # icdf is available); otherwise unbounded.
            def lower(x):
                if isinstance(x, prx.Parameter):
                    if x.bounds is not None:
                        low = x.bounds[..., 0]
                    elif x.distribution is not None and hasattr(x.distribution, 'icdf'):
                        low = x.distribution.icdf(0.01*jnp.ones_like(x.value))
                    else:
                        low = jnp.full_like(x.value, -jnp.inf)
                    return x.with_value(low)
                return x

            def upper(x):
                if isinstance(x, prx.Parameter):
                    if x.bounds is not None:
                        high = x.bounds[..., 1]
                    elif x.distribution is not None and hasattr(x.distribution, 'icdf'):
                        high = x.distribution.icdf(0.99*jnp.ones_like(x.value))
                    else:
                        high = jnp.full_like(x.value, jnp.inf)
                    return x.with_value(high)
                return x

            lower_tree = jax.tree.map(lower, problem, is_leaf=prx.is_free_param)
            upper_tree = jax.tree.map(upper, problem, is_leaf=prx.is_free_param)

        # Transform bounds BEFORE partitioning so the PyTree structure matches `problem`
        def apply_bound_transform(bound_val, orig_p):
            if isinstance(orig_p, prx.Parameter):
                return bound_val.transformed(get_bijector(orig_p))
            return bound_val

        if transform is not None:
            transformed_lower_tree = jax.tree.map(apply_bound_transform, lower_tree, problem, is_leaf=prx.is_free_param)
            transformed_upper_tree = jax.tree.map(apply_bound_transform, upper_tree, problem, is_leaf=prx.is_free_param)
        else:
            transformed_lower_tree = lower_tree
            transformed_upper_tree = upper_tree
        # Now strip out static attributes
        (transformed_lower, transformed_upper), _ = prx.partition((transformed_lower_tree, transformed_upper_tree))
        kwargs['bounds'] = (transformed_lower, transformed_upper)
        if kwargs.get('has_aux', False):
            raise ValueError("Auxiliary data not supported for host solvers")
        kwargs['maxiter'] = max_steps
        solution = solver(obj_fn, transformed_params, args=None, options=kwargs)
    else:
        solution = optx.minimise(obj_fn, solver, transformed_params, max_steps=max_steps, **kwargs)

    # 3. Get the solved problem and reconstruct the physical state
    solved_transformed_problem = eqx.combine(solution.value, transformed_static)
    if transform is not None:
        solved_problem = jax.tree.map(
            apply_inverse,
            problem,
            solved_transformed_problem,
            is_leaf=prx.is_free_param
        )
    else:
        solved_problem = solved_transformed_problem

    # 4. Standardize the results
    results = OptimizeResult(
        model=solved_problem.model,
        cost=solved_problem.evaluator,
        value=solved_problem(),
        stats=solution,
    )
    if isinstance(solver, optx.AbstractMinimiser):
        print(f"Final cost = {results.value}")
    return results