HingeLoss
- class pmrf.losses.HingeLoss(operator: Literal['<', '<=', '>', '>=', '==', '='] = '==', weight: float = 1.0, mask: Any = None, base_loss: str | Callable | AbstractLoss = RMSELoss(), multioutput: str | Array | Callable = 'uniform_average')
Bases: AbstractLoss
Applies a one-sided constraint (hinge) before evaluating a base metric.
Forwards to pmrf.math.losses.hinge_loss().
- Parameters:
operator (Literal['<', '<=', '>', '>=', '==', '=']) – The logical constraint operator (‘<’, ‘<=’, ‘>’, ‘>=’, ‘==’, ‘=’).
weight (float) – A scalar or array multiplier to scale the importance of the penalty.
mask (jax.Array | None) – A boolean array filtering which data points apply to this loss.
base_loss (str | Callable | pmrf.losses.AbstractLoss) – The underlying loss function.
multioutput (str | jax.Array | Callable) – Defines the aggregation strategy across multiple output dimensions.
- __call__(y_true: Array, y_pred: Array, **kwargs) Array
Compute the loss between true data and model predictions.
- Parameters:
y_true (jnp.ndarray) – The observed ground-truth data.
y_pred (jnp.ndarray) – The model’s predicted data.
**kwargs (dict) – Additional keyword arguments for loss computation.
- Returns:
The calculated loss value.
- Return type:
jnp.ndarray
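The idea of applying a one-sided constraint before a base metric can be illustrated with a minimal NumPy sketch. This is not the pmrf implementation: the helper names `hinge_residual` and `rmse` are hypothetical, and the sketch assumes the constraint reads `y_pred <operator> y_true`, that only the violating part of the residual is passed on, and that strict and non-strict operators are treated alike. The actual behaviour is defined by pmrf.math.losses.hinge_loss().

```python
import numpy as np


def hinge_residual(y_true, y_pred, operator="=="):
    """Keep only the part of (y_pred - y_true) that violates the constraint.

    Assumption: the constraint is read as `y_pred <operator> y_true`.
    """
    diff = np.asarray(y_pred, dtype=float) - np.asarray(y_true, dtype=float)
    if operator in ("<", "<="):
        # Predictions above the target violate an upper bound.
        return np.maximum(diff, 0.0)
    if operator in (">", ">="):
        # Predictions below the target violate a lower bound.
        return np.minimum(diff, 0.0)
    if operator in ("==", "="):
        # Equality penalises deviations in both directions.
        return diff
    raise ValueError(f"unsupported operator: {operator!r}")


def rmse(residual):
    """Root-mean-square of a residual array (stand-in for the base loss)."""
    return float(np.sqrt(np.mean(np.square(residual))))


y_true = np.array([1.0, 2.0, 3.0])
y_pred = np.array([0.5, 2.0, 5.0])

upper = rmse(hinge_residual(y_true, y_pred, "<"))   # only 5.0 > 3.0 is penalised
lower = rmse(hinge_residual(y_true, y_pred, ">"))   # only 0.5 < 1.0 is penalised
both = rmse(hinge_residual(y_true, y_pred, "=="))   # every deviation is penalised
```

Under these assumptions, an upper-bound constraint ignores points where the prediction already sits below the target, so `upper` and `lower` are each no larger than `both`.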
- base_loss: str | Callable | AbstractLoss = RMSELoss()
The underlying loss function.
- mask: Array | None = None
A boolean array filtering which data points apply to this loss.
- multioutput: str | Array | Callable = 'uniform_average'
Defines the aggregation strategy across multiple output dimensions.
- operator: Literal['<', '<=', '>', '>=', '==', '='] = '=='
The logical constraint operator (‘<’, ‘>’, ‘==’, etc.).
- weight: float = 1.0
A scalar or array multiplier to scale the importance of the penalty.
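How `mask`, `weight`, and `multioutput='uniform_average'` might interact can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the function `masked_weighted_rmse` is hypothetical, the mask is assumed to select rows (data points) of a 2-D residual, the per-output errors are assumed to be averaged with equal weight (the convention the name 'uniform_average' carries in scikit-learn), and the scalar weight is assumed to multiply the final value.

```python
import numpy as np


def masked_weighted_rmse(residual, mask=None, weight=1.0):
    """RMSE per output column over the masked rows, averaged uniformly, then scaled.

    Assumptions: `residual` has shape (n_points, n_outputs), `mask` is a boolean
    array of length n_points, and `weight` is a scalar.
    """
    residual = np.asarray(residual, dtype=float)
    if mask is None:
        mask = np.ones(residual.shape[0], dtype=bool)
    kept = residual[np.asarray(mask, dtype=bool)]          # drop masked-out points
    per_output = np.sqrt(np.mean(np.square(kept), axis=0))  # one RMSE per output
    return float(weight * np.mean(per_output))             # uniform average, scaled


residual = np.array([
    [3.0, 0.0],
    [4.0, 0.0],
    [100.0, 100.0],   # excluded by the mask below
])
mask = np.array([True, True, False])

value = masked_weighted_rmse(residual, mask=mask, weight=2.0)
unmasked = masked_weighted_rmse(residual)
```

With the last row masked out, the large residual there has no effect on `value`, while `unmasked` is dominated by it.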