"""pmrf.optimize.minimize — frequentist cost-function minimization for RF models."""

from typing import Callable

import jax
import jax.numpy as jnp
import equinox as eqx
import optimistix as optx
import parax as prx
from distreqx.bijectors import AbstractBijector

from pmrf.core import Model, Frequency, Problem
from pmrf.optimize.result import OptimizeResult
from pmrf.optimize.solvers import ScipyMinimizer

def is_optimizer(x):
    """
    Check whether a solver is suitable for frequentist optimization in
    :mod:`pmrf.optimize`.

    Returns ``True`` for ``optimistix.AbstractMinimiser`` instances and
    :class:`pmrf.optimize.ScipyMinimizer` instances.
    """
    return isinstance(x, (optx.AbstractMinimiser, ScipyMinimizer))
def minimize(
    cost_fn: Callable[[Model, Frequency], jnp.ndarray] | list[Callable],
    model: Model,
    frequency: Frequency,
    solver: optx.AbstractMinimiser | Callable = ScipyMinimizer(),
    *,
    transform: AbstractBijector | Callable[[prx.Parameter], AbstractBijector] | None = None,
    max_steps: int = 10000,
    **kwargs,
) -> OptimizeResult:
    """
    Minimizes a given cost function for a model over a frequency range.

    The cost function can have its own hyper-parameters, and is returned in
    ``result.cost``.

    Parameters
    ----------
    cost_fn : Callable[[Model, Frequency], jnp.ndarray] | list[Callable]
        The cost function to minimize. Must be a callable or PyTree with
        signature (model, freq) -> jnp.ndarray. If a list of costs is
        provided, they are automatically summed. See
        :meth:`pmrf.evaluators.Goal` for an easy way to define goal-based
        cost functions.
    model : Model
        The RF model containing the parameters to be optimized.
    frequency : Frequency
        The frequency sweep over which the cost should be evaluated.
    solver : optx.AbstractMinimiser | Callable, default=ScipyMinimizer()
        The optimization backend to use. Defaults to the host-based SciPy
        L-BFGS-B.
    transform : distreqx.bijectors.AbstractBijector, default=None
        An invertible transformation to apply to all model parameters before
        optimization. May also be a callable mapping a parameter to its
        bijector, for per-parameter transforms.
    max_steps : int, default=10000
        The maximum number of steps/iterations the underlying solver can take.
    **kwargs : dict
        Additional options passed to the underlying solver backend.

    Returns
    -------
    OptimizeResult
        A structured result containing the fitted model and solver statistics.

    Raises
    ------
    ValueError
        If the model has no free parameters to fit, or if ``has_aux`` is
        requested with a host-based (SciPy) solver.
    """
    # Hoisted out of the per-leaf callback below so the import runs once per
    # call rather than once per parameter leaf.
    from distreqx.bijectors import Inverse

    # Normalize the cost to a single eqx.Module PyTree; a list of costs is
    # automatically summed.
    if isinstance(cost_fn, list):
        cost_fn = prx.op.Sum(
            [c if isinstance(c, eqx.Module) else prx.op.Lambda(c) for c in cost_fn]
        )
    elif not isinstance(cost_fn, eqx.Module):
        cost_fn = prx.op.Lambda(cost_fn)

    problem = Problem(model=model, frequency=frequency, evaluator=cost_fn)
    if problem.num_flat_params == 0:
        raise ValueError("Model has no free parameters to fit")

    # Helper to dynamically resolve the bijector for a given parameter.
    def get_bijector(p: prx.Parameter) -> AbstractBijector:
        return transform(p) if callable(transform) else transform

    # Map a solver-space leaf back to physical space via the inverse bijector.
    def apply_inverse(orig_x, trans_x):
        if isinstance(orig_x, prx.Parameter):
            return trans_x.transformed(Inverse(get_bijector(orig_x)))
        return trans_x

    # 1. Apply the parameter space transform dynamically.
    if transform is not None:
        transformed_problem = jax.tree.map(
            lambda x: x.transformed(get_bijector(x)) if isinstance(x, prx.Parameter) else x,
            problem,
            is_leaf=prx.is_free_param,
        )
    else:
        transformed_problem = problem

    transformed_params, transformed_static = prx.partition(transformed_problem)

    def obj_fn(transformed_params, _args):
        full_transformed = eqx.combine(transformed_params, transformed_static)
        # Map over both the source of truth (problem) and the solver state
        # simultaneously, so the inverse sees the original parameter.
        if transform is not None:
            full_physical = jax.tree.map(
                apply_inverse, problem, full_transformed, is_leaf=prx.is_free_param
            )
        else:
            full_physical = full_transformed
        return full_physical()

    # 2. Routing logic for bounding and solver execution.
    if isinstance(solver, ScipyMinimizer):
        if 'bounds' in kwargs:
            lower_tree, upper_tree = kwargs.pop('bounds')
        else:
            # One-sided default bound for a parameter: explicit bounds win,
            # then the prior's quantile (when an icdf exists), then infinity.
            def default_bound(x, side, quantile, fill):
                if isinstance(x, prx.Parameter):
                    if x.bounds is not None:
                        val = x.bounds[..., side]
                    elif x.distribution is not None and hasattr(x.distribution, 'icdf'):
                        val = x.distribution.icdf(quantile * jnp.ones_like(x.value))
                    else:
                        val = jnp.full_like(x.value, fill)
                    return x.with_value(val)
                return x

            lower_tree = jax.tree.map(
                lambda x: default_bound(x, 0, 0.01, -jnp.inf),
                problem, is_leaf=prx.is_free_param,
            )
            upper_tree = jax.tree.map(
                lambda x: default_bound(x, 1, 0.99, jnp.inf),
                problem, is_leaf=prx.is_free_param,
            )

        # Transform bounds BEFORE partitioning so the PyTree structure
        # matches `problem`.
        def apply_bound_transform(bound_val, orig_p):
            if isinstance(orig_p, prx.Parameter):
                return bound_val.transformed(get_bijector(orig_p))
            return bound_val

        if transform is not None:
            transformed_lower_tree = jax.tree.map(
                apply_bound_transform, lower_tree, problem, is_leaf=prx.is_free_param
            )
            transformed_upper_tree = jax.tree.map(
                apply_bound_transform, upper_tree, problem, is_leaf=prx.is_free_param
            )
        else:
            transformed_lower_tree = lower_tree
            transformed_upper_tree = upper_tree

        # Now strip out static attributes.
        (transformed_lower, transformed_upper), _ = prx.partition(
            (transformed_lower_tree, transformed_upper_tree)
        )
        kwargs['bounds'] = (transformed_lower, transformed_upper)

        if kwargs.get('has_aux', False):
            raise ValueError("Auxiliary data not supported for host solvers")
        kwargs['maxiter'] = max_steps
        solution = solver(obj_fn, transformed_params, args=None, options=kwargs)
    else:
        solution = optx.minimise(
            obj_fn, solver, transformed_params, max_steps=max_steps, **kwargs
        )

    # 3. Get the solved problem and reconstruct the physical state.
    solved_transformed_problem = eqx.combine(solution.value, transformed_static)
    if transform is not None:
        solved_problem = jax.tree.map(
            apply_inverse, problem, solved_transformed_problem,
            is_leaf=prx.is_free_param,
        )
    else:
        solved_problem = solved_transformed_problem

    # 4. Standardize the results.
    results = OptimizeResult(
        model=solved_problem.model,
        cost=solved_problem.evaluator,
        value=solved_problem(),
        stats=solution,
    )
    if isinstance(solver, optx.AbstractMinimiser):
        # NOTE(review): consider routing this through `logging` instead of print.
        print(f"Final cost = {results.value}")
    return results