NelderMead

class pmrf.optimize.NelderMead(fatol: float = 1e-07, xatol: float = 1e-06, xrtol: float = 1e-06, norm: Callable[[PyTree], Shaped[Array, '']] = optx.max_norm, rdelta: float = 0.05, adelta: float = 0.00025)

Bases: AbstractUnconstrainedMinimizer

A Nelder-Mead optimizer in JAX.

Wrapper around optimistix.NelderMead.

Parameters:
  • fatol (float, default=1e-7) – Absolute tolerance of the function value for termination.

  • xatol (float, default=1e-6) – Absolute tolerance of the simplex for termination.

  • xrtol (float, default=1e-6) – Relative tolerance of the simplex for termination.

  • norm (Callable[[PyTree], Scalar], default=optx.max_norm) – Norm used to measure the error that is compared against the termination tolerances.

  • rdelta (float, default=5e-2) – Relative delta for the initial simplex.

  • adelta (float, default=2.5e-4) – Absolute delta for the initial simplex.
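The roles of rdelta and adelta can be illustrated with a small pure-Python sketch. It assumes the common construction in which the initial simplex is y0 plus one extra vertex per coordinate, with coordinate i of that vertex shifted by rdelta * y0_i + adelta. The exact formula used by the backend may differ, and initial_simplex is a hypothetical helper, not part of pmrf or optimistix.

```python
# Hypothetical helper (not a pmrf/optimistix API): builds an initial simplex
# of n + 1 vertices around a flat starting point y0.
def initial_simplex(y0, rdelta=0.05, adelta=0.00025):
    """Return y0 itself plus one perturbed vertex per coordinate.

    Assumed perturbation for coordinate i: rdelta * y0[i] + adelta.
    rdelta scales the step with the parameter's magnitude, while adelta
    keeps the step nonzero when y0[i] == 0.
    """
    vertices = [list(y0)]
    for i, yi in enumerate(y0):
        vertex = list(y0)
        vertex[i] = yi + rdelta * yi + adelta
        vertices.append(vertex)
    return vertices


simplex = initial_simplex([2.0, 0.0])
# Approximately [[2.0, 0.0], [2.10025, 0.0], [2.0, 0.00025]]
```

Under this assumption a zero-valued parameter still gets a step of adelta, while a large parameter gets a step roughly proportional to its size.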

norm(x: PyTree) → Shaped[Array, '']

Compute the L-infinity norm of a PyTree of arrays.

This is the largest absolute elementwise value: treating the input x as a flat vector (x_1, …, x_n), it computes max_i |x_i|.
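A minimal pure-Python sketch of this norm, using nested lists, tuples and dicts as a stand-in for a PyTree, together with one plausible way a norm can combine with fatol, xatol and xrtol into a termination test. The names leaves, max_norm and simplex_converged are hypothetical, and the backend's actual convergence criterion may differ.

```python
# Hypothetical stand-ins (not pmrf/optimistix APIs): nested lists, tuples
# and dicts of floats play the role of a PyTree.
def leaves(tree):
    """Yield the scalar leaves of a nested list/tuple/dict."""
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from leaves(tree[key])
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from leaves(item)
    else:
        yield tree


def max_norm(tree):
    """L-infinity norm: max_i |x_i| over all leaves (0.0 for an empty tree here)."""
    return max((abs(x) for x in leaves(tree)), default=0.0)


def simplex_converged(vertices, values, fatol=1e-7, xatol=1e-6, xrtol=1e-6,
                      norm=max_norm):
    """Assumed test: every vertex is close to the best one and every
    function value is within fatol of the best value."""
    best = min(range(len(values)), key=lambda k: values[k])
    x_best, f_best = vertices[best], values[best]
    # Per-coordinate scale: absolute floor plus a part relative to |x_best|.
    scale = [xatol + xrtol * abs(b) for b in x_best]
    scaled_dx = [[abs(a - b) / s for a, b, s in zip(v, x_best, scale)]
                 for v in vertices]
    x_ok = norm(scaled_dx) < 1.0
    f_ok = norm([f - f_best for f in values]) < fatol
    return x_ok and f_ok
```

For example, max_norm({"a": [1.0, -3.0], "b": 2.0}) returns 3.0. Taking the maximum means a single badly converged coordinate is enough to keep the solver running.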

run(fn: Callable[[PyTree, Any], Any], y0: PyTree, args: Any, max_iter: int, **kwargs) → tuple[MinimizeResult, PyTree]

Execute the minimization algorithm.

Parameters:
  • fn (callable) – The objective function to minimize, called as fn(y, args).

  • y0 (PyTree) – The initial parameter guess.

  • args (Any) – Extra arguments passed to fn as its second argument.

  • max_iter (int, default=1024) – The maximum number of iterations to take.

  • **kwargs – Runtime arguments forwarded to the solver backend.

Returns:

A tuple of (pmrf.optimize.MinimizeResult, metrics).

Return type:

tuple
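To make the overall procedure concrete, here is a toy pure-Python Nelder-Mead with a run-like call shape (fn(y, args), y0, args, max_iter) that returns a (point, metrics) pair, loosely mirroring the documented return value. It is a textbook variant (reflection, expansion, inside contraction, shrink) with a simplified stopping rule. toy_run and its metrics dict are hypothetical and are not the pmrf or optimistix implementation.

```python
# Toy Nelder-Mead with standard coefficients: reflect 1, expand 2,
# contract 0.5, shrink 0.5. Illustrative only; `toy_run` is hypothetical.
def toy_run(fn, y0, args, max_iter=1024, rdelta=0.05, adelta=0.00025,
            fatol=1e-7, xatol=1e-6):
    n = len(y0)
    # Initial simplex: y0 plus one perturbed vertex per coordinate.
    simplex = [list(y0)]
    for i in range(n):
        v = list(y0)
        v[i] += rdelta * v[i] + adelta
        simplex.append(v)
    values = [fn(v, args) for v in simplex]

    def lerp(a, b, t):
        # Coordinate-wise a + t * (b - a).
        return [ai + t * (bi - ai) for ai, bi in zip(a, b)]

    steps = 0
    for steps in range(1, max_iter + 1):
        # Sort vertices from best (lowest fn) to worst.
        order = sorted(range(n + 1), key=lambda k: values[k])
        simplex = [simplex[k] for k in order]
        values = [values[k] for k in order]
        best, worst = simplex[0], simplex[-1]
        # Simplified stopping rule: simplex and its values have collapsed.
        spread_x = max(abs(a - b) for v in simplex[1:] for a, b in zip(v, best))
        if spread_x <= xatol and values[-1] - values[0] <= fatol:
            break
        # Centroid of every vertex except the worst.
        centroid = [sum(c) / n for c in zip(*simplex[:-1])]
        reflected = lerp(centroid, worst, -1.0)  # 2 * centroid - worst
        f_r = fn(reflected, args)
        if f_r < values[0]:
            expanded = lerp(centroid, worst, -2.0)  # 3 * centroid - 2 * worst
            f_e = fn(expanded, args)
            simplex[-1], values[-1] = (expanded, f_e) if f_e < f_r else (reflected, f_r)
        elif f_r < values[-2]:
            simplex[-1], values[-1] = reflected, f_r
        else:
            # Inside contraction: halfway between the centroid and the worst vertex.
            contracted = lerp(centroid, worst, 0.5)
            f_c = fn(contracted, args)
            if f_c < values[-1]:
                simplex[-1], values[-1] = contracted, f_c
            else:
                # Shrink every other vertex halfway toward the best one.
                simplex = [best] + [lerp(best, v, 0.5) for v in simplex[1:]]
                values = [values[0]] + [fn(v, args) for v in simplex[1:]]
    k = min(range(n + 1), key=lambda j: values[j])
    return simplex[k], {"num_steps": steps, "value": values[k]}


def quad(y, args):
    # Simple quadratic with its minimum at (cx, cy) = args.
    cx, cy = args
    return (y[0] - cx) ** 2 + (y[1] - cy) ** 2


y_opt, metrics = toy_run(quad, [0.0, 0.0], (1.0, -2.0))
```

Here y_opt should land close to (1, -2). Only function values are compared, never gradients, which is why Nelder-Mead suits objectives that are cheap to evaluate but awkward to differentiate.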

adelta: float = 0.00025
fatol: float = 1e-07
rdelta: float = 0.05
xatol: float = 1e-06
xrtol: float = 1e-06