HuberLoss

class pmrf.losses.HuberLoss(delta: float = 1.0, multioutput: str | Array | Callable = 'uniform_average')

Bases: AbstractLoss

Huber loss metric.

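A robust loss function that behaves like squared error for residuals whose magnitude is at most delta and like absolute error beyond that threshold. This makes it less sensitive to outliers than a pure squared-error loss.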

Forwards to pmrf.math.losses.huber_loss().
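For reference, the textbook Huber loss of a single residual with threshold δ is given below. It is an assumption that pmrf.math.losses.huber_loss() uses exactly this convention, since the scaling (for example the factor ½) is not stated here. The two branches meet at |y − ŷ| = δ, where both equal ½δ², so the loss is continuous there.

```latex
L_\delta(y, \hat{y}) =
\begin{cases}
  \tfrac{1}{2}\,(y - \hat{y})^2, & |y - \hat{y}| \le \delta, \\[4pt]
  \delta\left(|y - \hat{y}| - \tfrac{1}{2}\,\delta\right), & |y - \hat{y}| > \delta.
\end{cases}
```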

Parameters:
  • delta (float) – The threshold at which to change between squared error and absolute error.

  • multioutput (str | jax.Array | Callable) – Defines the aggregation strategy across multiple output dimensions.

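The effect of delta can be illustrated with a minimal JAX sketch of the elementwise loss. It assumes the textbook form (quadratic inside delta, linear outside). The helper name huber_elementwise is hypothetical and is not part of pmrf.

```python
import jax.numpy as jnp


def huber_elementwise(y_true, y_pred, delta=1.0):
    """Textbook Huber loss per element (hypothetical helper, not pmrf's API)."""
    residual = jnp.abs(jnp.asarray(y_true) - jnp.asarray(y_pred))
    quadratic = 0.5 * residual**2               # branch used where |r| <= delta
    linear = delta * (residual - 0.5 * delta)   # branch used where |r| > delta
    return jnp.where(residual <= delta, quadratic, linear)


# A residual of 3.0 falls in the linear branch for delta=1.0 (1 * (3 - 0.5) = 2.5),
# but in the quadratic branch for delta=5.0 (0.5 * 3**2 = 4.5).
small = huber_elementwise(0.0, 3.0, delta=1.0)
large = huber_elementwise(0.0, 3.0, delta=5.0)
```

A larger delta treats more residuals quadratically, so the loss moves toward plain squared error. A smaller delta caps the influence of large residuals sooner.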
__call__(y_true: Array, y_pred: Array, **kwargs) → Array

Compute the loss between true data and model predictions.

Parameters:
  • y_true (jnp.ndarray) – The observed ground-truth data.

  • y_pred (jnp.ndarray) – The model’s predicted data.

  • **kwargs (dict) – Additional keyword arguments for loss computation.

Returns:

The calculated loss value.

Return type:

jnp.ndarray
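The call protocol can be sketched with a small stand-in class. It is a hypothetical illustration, not pmrf's implementation, and it makes several assumptions. Inputs have shape (n_samples, n_outputs). The loss is averaged over samples first. The string "uniform_average" then takes an unweighted mean over outputs, matching the scikit-learn meaning of that name. An array is treated as per-output weights, and a callable reduces the per-output vector to a scalar.

```python
from dataclasses import dataclass
from typing import Callable, Union

import jax.numpy as jnp


@dataclass(frozen=True)
class HuberLossSketch:
    """Hypothetical stand-in for pmrf.losses.HuberLoss (not the real class)."""

    delta: float = 1.0
    multioutput: Union[str, jnp.ndarray, Callable] = "uniform_average"

    def __call__(self, y_true, y_pred, **kwargs):
        # **kwargs is accepted for signature compatibility and ignored here.
        r = jnp.abs(jnp.asarray(y_true) - jnp.asarray(y_pred))
        # Elementwise Huber loss: quadratic inside delta, linear outside.
        elem = jnp.where(r <= self.delta, 0.5 * r**2,
                         self.delta * (r - 0.5 * self.delta))
        # Mean over samples (axis 0) gives one loss value per output column.
        per_output = jnp.mean(elem, axis=0)
        if isinstance(self.multioutput, str):
            if self.multioutput == "uniform_average":
                return jnp.mean(per_output)
            raise ValueError(f"unsupported multioutput: {self.multioutput!r}")
        if callable(self.multioutput):
            # Assumed: a callable reduces the per-output vector to a scalar.
            return self.multioutput(per_output)
        # Assumed: an array supplies per-output weights for a weighted average.
        return jnp.average(per_output, weights=jnp.asarray(self.multioutput))


y_true = jnp.zeros((2, 2))
y_pred = jnp.array([[0.5, 3.0], [1.0, -2.0]])
loss = HuberLossSketch(delta=1.0)(y_true, y_pred)
```

With delta=1.0, the per-output means here are 0.3125 and 2.0, so the uniform average is 1.15625.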

delta: float = 1.0

The threshold at which to change between squared error and absolute error.

multioutput: str | Array | Callable = 'uniform_average'

Defines the aggregation strategy across multiple output dimensions.