LBFGS
- class pmrf.optimize.LBFGS(fatol: float = 1e-07, step_atol: float = 1e-06, step_rtol: float = 1e-06, norm: Callable[[PyTree], Shaped[Array, '']] = optimistix.max_norm, use_inverse: bool = True)
Bases: AbstractUnconstrainedMinimizer
An L-BFGS (limited-memory BFGS) optimizer in JAX, implemented as a wrapper around optimistix.LBFGS.
- Parameters:
fatol (float, default=1e-7) – Absolute tolerance of the function value for termination.
step_atol (float, default=1e-6) – Absolute tolerance of the gradients and step sizes for termination.
step_rtol (float, default=1e-6) – Relative tolerance of the gradients and step sizes for termination.
norm (Callable[[PyTree], Scalar], default=optx.max_norm) – Norm function used to evaluate the error.
use_inverse (bool, default=True) – Whether to use the inverse Hessian approximation.
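The use_inverse option refers to an inverse-Hessian approximation. To make that idea concrete, here is a minimal sketch of the textbook L-BFGS two-loop recursion. It applies an inverse-Hessian approximation, built from the last m step and gradient-difference pairs, directly to a gradient. This illustrates the general technique only and is not pmrf's or optimistix's implementation. The function name two_loop_direction and the (m, n) history layout (oldest pair first) are our own assumptions.

```python
import jax.numpy as jnp


def two_loop_direction(grad, s_hist, y_hist):
    """Return the L-BFGS search direction -H_k @ grad (illustrative sketch).

    s_hist, y_hist: arrays of shape (m, n), oldest pair first, where
    s_i = x_{i+1} - x_i and y_i = g_{i+1} - g_i. Assumes y_i . s_i > 0.
    """
    m = s_hist.shape[0]
    rho = 1.0 / jnp.sum(y_hist * s_hist, axis=1)  # rho_i = 1 / (y_i . s_i)

    # First loop: walk the history from newest to oldest pair.
    q = grad
    alphas = []
    for i in reversed(range(m)):
        a = rho[i] * jnp.dot(s_hist[i], q)
        q = q - a * y_hist[i]
        alphas.append(a)
    alphas = alphas[::-1]  # reorder so alphas[i] belongs to pair i

    # Initial inverse Hessian H_0 = gamma * I, scaled from the newest pair.
    gamma = jnp.dot(s_hist[-1], y_hist[-1]) / jnp.dot(y_hist[-1], y_hist[-1])
    r = gamma * q

    # Second loop: walk the history from oldest to newest pair.
    for i in range(m):
        b = rho[i] * jnp.dot(y_hist[i], r)
        r = r + s_hist[i] * (alphas[i] - b)

    # r approximates H_k @ grad; the descent direction is its negative.
    return -r
```

A quick sanity check is the secant condition: for the newest pair, the approximation maps y to s, so two_loop_direction(y_hist[-1], s_hist, y_hist) should return -s_hist[-1]. Storing only m pairs keeps memory at O(mn) instead of the O(n²) needed for a dense (inverse) Hessian.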
- norm(x) → Shaped[Array, '']
Compute the L-infinity norm of a PyTree of arrays.
This is the largest absolute elementwise value: treating the input x as a flat vector (x_1, …, x_n), it computes max_i |x_i|.
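The L-infinity norm over a PyTree can be sketched in a few lines of JAX. Flatten the tree into its array leaves, take each leaf's largest absolute entry, then take the maximum across leaves. The name linf_norm is hypothetical. This is an illustration of the definition above, not the optimistix source, and it assumes every leaf is non-empty.

```python
import jax
import jax.numpy as jnp


def linf_norm(tree):
    """Hypothetical sketch of max_i |x_i| over all entries of a PyTree."""
    # Collect the array leaves of the PyTree (dicts, tuples, lists, ...).
    leaves = jax.tree_util.tree_leaves(tree)
    if not leaves:
        # An empty tree has no entries; return 0 by convention.
        return jnp.array(0.0)
    # Largest absolute entry of each leaf, then the max across leaves.
    per_leaf = [jnp.max(jnp.abs(jnp.asarray(leaf))) for leaf in leaves]
    return jnp.max(jnp.stack(per_leaf))
```

For example, a tree holding [1.0, -3.0], 2.5 and [[0.5, -1.0]] has norm 3.0, regardless of how the leaves are nested.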
- run(fn: Callable[[PyTree, Any], Any], y0: PyTree, args: Any, max_iter: int, **kwargs) → tuple[MinimizeResult, PyTree]
Execute the minimization algorithm.
- Parameters:
fn (callable) – The objective function to minimize, called as fn(y, args).
y0 (PyTree) – The initial parameter guess.
args (Any) – Additional arguments passed to fn as its second argument.
max_iter (int, default=1024) – The maximum number of iterations to take.
**kwargs – Runtime arguments forwarded to the solver backend.
- Returns:
A tuple (result, metrics), where result is a pmrf.optimize.MinimizeResult and metrics is a PyTree.
- Return type:
tuple
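The fn(y, args) calling convention works with any PyTree of parameters. Below is a small sketch of an objective in that shape: a hypothetical least-squares line fit with parameters in a dict and the data passed through args. It evaluates the value and gradient with jax.value_and_grad, the quantities a gradient-based minimizer such as L-BFGS works with. The objective and data are illustrative assumptions, and pmrf itself is not called here.

```python
import jax
import jax.numpy as jnp


def objective(params, args):
    """Hypothetical objective: mean squared error of the line w * x + b."""
    x, y_obs = args  # data travels through `args`, not through the PyTree
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y_obs) ** 2)


# Synthetic data generated from w = 2, b = 1.
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y_obs = 2.0 * x + 1.0

# Initial guess: a PyTree (here a dict of scalar arrays).
y0 = {"w": jnp.array(0.0), "b": jnp.array(0.0)}

# Value and gradient at y0; the gradient has the same structure as y0.
value, grads = jax.value_and_grad(objective)(y0, (x, y_obs))
```

Here value is 21.0, and grads is {"w": -17.0, "b": -8.0}. Given the signature above, the same objective would presumably be passed as run(objective, y0, (x, y_obs), max_iter=1024) on an LBFGS instance. That call is not exercised in this sketch.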
- fatol: float = 1e-07
- step_atol: float = 1e-06
- step_rtol: float = 1e-06
- use_inverse: bool = True