HingeLoss

class pmrf.losses.HingeLoss(operator: Literal['<', '<=', '>', '>=', '==', '='] = '==', weight: float = 1.0, mask: Any = None, base_loss: str | Callable | AbstractLoss = RMSELoss(), multioutput: str | Array | Callable = 'uniform_average')

Bases: AbstractLoss

Applies a one-sided constraint (hinge), selected by operator, before evaluating the base loss.

Forwards to pmrf.math.losses.hinge_loss().

Parameters:
  • operator (Literal['<', '<=', '>', '>=', '==', '=']) – The logical constraint operator: one of ‘<’, ‘<=’, ‘>’, ‘>=’, ‘==’, or ‘=’.

  • weight (float) – A scalar or array multiplier to scale the importance of the penalty.

  • mask (jax.Array | None) – A boolean array that determines which data points this loss applies to.

  • base_loss (str | Callable | pmrf.losses.AbstractLoss) – The underlying loss function.

  • multioutput (str | jax.Array | Callable) – Defines the aggregation strategy across multiple output dimensions.
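
The interaction between operator, weight, and mask can be illustrated with a minimal standalone sketch. This is not the pmrf implementation: the helper names hinge_residual and hinged_rmse are hypothetical, and the sketch assumes that the constraint is read as "y_pred OPERATOR y_true", that only the violating part of the residual is passed to an RMSE base loss, and that masked-out points are excluded from the average.

```python
import jax.numpy as jnp


def hinge_residual(y_true, y_pred, operator="<="):
    """Hypothetical helper: keep only the part of the residual that violates the constraint."""
    diff = y_pred - y_true
    if operator in ("<", "<="):
        # Assumed reading: y_pred should not exceed y_true, so only overshoot is kept.
        return jnp.maximum(diff, 0.0)
    if operator in (">", ">="):
        # Assumed reading: y_pred should not fall below y_true, so only undershoot is kept.
        return jnp.maximum(-diff, 0.0)
    if operator in ("==", "="):
        # Equality: any deviation counts, so the residual passes through unchanged.
        return diff
    raise ValueError(f"unsupported operator: {operator!r}")


def hinged_rmse(y_true, y_pred, operator="<=", weight=1.0, mask=None):
    """Hypothetical helper: weighted RMSE of the constraint violation."""
    violation = hinge_residual(y_true, y_pred, operator)
    if mask is not None:
        # Zero out excluded points and average over the selected ones only.
        violation = jnp.where(mask, violation, 0.0)
        count = jnp.maximum(jnp.sum(mask), 1)
    else:
        count = violation.size
    return weight * jnp.sqrt(jnp.sum(violation**2) / count)


y_true = jnp.array([1.0, 1.0, 1.0, 1.0])
y_pred = jnp.array([0.5, 1.0, 2.0, 3.0])

upper_bound_loss = hinged_rmse(y_true, y_pred, operator="<=")
lower_bound_loss = hinged_rmse(y_true, y_pred, operator=">=")
masked_loss = hinged_rmse(
    y_true, y_pred, operator="<=", mask=jnp.array([True, True, True, False])
)
```

Under these assumptions, points that already satisfy the constraint contribute zero to the loss, so an optimizer is pushed only by the violating points.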

__call__(y_true: Array, y_pred: Array, **kwargs) → Array

Compute the loss between true data and model predictions.

Parameters:
  • y_true (jnp.ndarray) – The observed ground-truth data.

  • y_pred (jnp.ndarray) – The model’s predicted data.

  • **kwargs (dict) – Additional keyword arguments for loss computation.

Returns:

The calculated loss value.

Return type:

jnp.ndarray

base_loss: str | Callable | AbstractLoss = RMSELoss()

The underlying loss function.

mask: Array | None = None

A boolean array that determines which data points this loss applies to.

multioutput: str | Array | Callable = 'uniform_average'

Defines the aggregation strategy across multiple output dimensions.
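
As a rough illustration of what an aggregation strategy over output dimensions can look like, the sketch below reduces a vector of per-output losses to a scalar. The helper aggregate_outputs is hypothetical and not part of pmrf. It assumes that 'uniform_average' means an unweighted mean, that an array is interpreted as per-output weights for a normalized weighted average, and that a callable receives the per-output losses directly. The strategies pmrf actually accepts may differ.

```python
import jax.numpy as jnp


def aggregate_outputs(per_output_loss, multioutput="uniform_average"):
    """Hypothetical helper: reduce per-output losses to a single value."""
    if callable(multioutput):
        # A user-supplied reduction is applied as-is.
        return multioutput(per_output_loss)
    if isinstance(multioutput, str):
        if multioutput == "uniform_average":
            # Every output dimension counts equally.
            return jnp.mean(per_output_loss)
        raise ValueError(f"unsupported strategy: {multioutput!r}")
    # Otherwise, treat the argument as an array of per-output weights.
    weights = jnp.asarray(multioutput)
    return jnp.sum(weights * per_output_loss) / jnp.sum(weights)


per_output = jnp.array([1.0, 3.0])

uniform = aggregate_outputs(per_output)
weighted = aggregate_outputs(per_output, jnp.array([3.0, 1.0]))
worst_case = aggregate_outputs(per_output, jnp.max)
```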

operator: Literal['<', '<=', '>', '>=', '==', '='] = '=='

The logical constraint operator: one of ‘<’, ‘<=’, ‘>’, ‘>=’, ‘==’, or ‘=’.

weight: float = 1.0

A scalar or array multiplier to scale the importance of the penalty.