RMSELoss

class pmrf.losses.RMSELoss(multioutput: str | Array | Callable = 'uniform_average')

Bases: AbstractLoss

Root Mean Squared Error (RMSE) metric.

Forwards to pmrf.math.losses.root_mean_squared_error().
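For reference, the standard RMSE is written below. This is a sketch that assumes each output column j is scored separately over n samples and then aggregated over m outputs. The uniform average shown is only an assumed reading of the default multioutput='uniform_average'; it is not taken from the pmrf source.

```latex
\mathrm{RMSE}_j = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_{ij} - \hat{y}_{ij}\right)^2},
\qquad
\mathrm{RMSE}_{\text{uniform}} = \frac{1}{m}\sum_{j=1}^{m}\mathrm{RMSE}_j
```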

Parameters:

multioutput (str | jax.Array | Callable) – Defines how the loss is aggregated across multiple output dimensions. Defaults to 'uniform_average'.
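The sketch below shows one way the three accepted types of multioutput could be handled. It is illustrative only. The name rmse_sketch is hypothetical, and the semantics are assumptions that are not confirmed by the pmrf source: the string 'uniform_average' averages the per-output RMSE values equally, an array supplies one weight per output, and a callable reduces the vector of per-output values.

```python
import jax.numpy as jnp


def rmse_sketch(y_true, y_pred, multioutput="uniform_average"):
    """Hypothetical stand-in for root_mean_squared_error; not the pmrf implementation."""
    y_true = jnp.asarray(y_true)
    y_pred = jnp.asarray(y_pred)
    # Treat 1-D inputs as a single output column.
    if y_true.ndim == 1:
        y_true = y_true[:, None]
        y_pred = y_pred[:, None]
    # Per-output RMSE: reduce over the sample axis (0), one value per output column.
    per_output = jnp.sqrt(jnp.mean((y_true - y_pred) ** 2, axis=0))
    if isinstance(multioutput, str):
        if multioutput == "uniform_average":
            # Assumed meaning: equal weight for every output.
            return jnp.mean(per_output)
        raise ValueError(f"unsupported multioutput string: {multioutput!r}")
    if callable(multioutput):
        # Assumed meaning: the callable reduces the per-output vector to the final loss.
        return multioutput(per_output)
    # Assumed meaning: an array holds one weight per output column.
    weights = jnp.asarray(multioutput)
    return jnp.average(per_output, weights=weights)
```

With y_true all zeros and every row of y_pred equal to [3.0, 1.0], the per-output RMSE values are [3.0, 1.0]. Under these assumptions, the default averages them to 2.0, weights [1.0, 3.0] give 1.5, and passing jnp.max returns 3.0.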

__call__(y_true: Array, y_pred: Array, **kwargs) → Array

Compute the loss between true data and model predictions.

Parameters:
  • y_true (jnp.ndarray) – The observed ground-truth data.

  • y_pred (jnp.ndarray) – The model’s predicted data.

  • **kwargs (dict) – Additional keyword arguments for loss computation.

Returns:

The calculated loss value.

Return type:

jnp.ndarray
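__call__ returns a jnp.ndarray, so the loss can be differentiated with JAX transformations. The sketch below is a minimal illustration and does not reproduce pmrf code. RMSELossSketch and model are hypothetical names. RMSELossSketch copies the documented __call__(y_true, y_pred, **kwargs) signature and handles only the 'uniform_average' case. A gradient with respect to the parameters of a toy linear model is then taken with jax.value_and_grad.

```python
import jax
import jax.numpy as jnp


class RMSELossSketch:
    """Hypothetical stand-in mirroring the documented __call__ signature."""

    def __init__(self, multioutput="uniform_average"):
        self.multioutput = multioutput  # only 'uniform_average' is handled in this sketch

    def __call__(self, y_true, y_pred, **kwargs):
        # Per-output RMSE over the sample axis, then a uniform average over outputs.
        per_output = jnp.sqrt(jnp.mean((y_true - y_pred) ** 2, axis=0))
        return jnp.mean(per_output)


def model(params, x):
    # Hypothetical linear model y = a * x + b with shape (n_samples, 1).
    return params["a"] * x + params["b"]


loss_fn = RMSELossSketch()
x = jnp.linspace(0.0, 1.0, 5)[:, None]
y_true = 2.0 * x + 1.0  # synthetic targets from a = 2, b = 1


def objective(params):
    return loss_fn(y_true, model(params, x))


params = {"a": jnp.array(1.0), "b": jnp.array(0.0)}
value, grads = jax.value_and_grad(objective)(params)
```

In this sketch the predictions start below the targets, so both gradients are negative. Note one property of the square root: if every residual is exactly zero, its gradient evaluates to NaN. Any implementation that needs to be differentiable at a perfect fit has to handle that case.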

multioutput: str | Array | Callable = 'uniform_average'

Defines the aggregation strategy across multiple output dimensions.