RMSELoss
- class pmrf.losses.RMSELoss(multioutput: str | Array | Callable = 'uniform_average')
Bases: AbstractLoss
Root Mean Squared Error (RMSE) metric.
Forwards to pmrf.math.losses.root_mean_squared_error().
- Parameters:
multioutput (str | jax.jaxlib._jax.Array | Callable) – Defines the aggregation strategy across multiple output dimensions.
- __call__(y_true: Array, y_pred: Array, **kwargs) → Array
Compute the loss between true data and model predictions.
- Parameters:
y_true (jnp.ndarray) – The observed ground-truth data.
y_pred (jnp.ndarray) – The model’s predicted data.
**kwargs (dict) – Additional keyword arguments for loss computation.
- Returns:
The calculated loss value.
- Return type:
jnp.ndarray
- multioutput: str | Array | Callable = 'uniform_average'
Defines the aggregation strategy across multiple output dimensions.
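To make the three accepted forms of multioutput (string, array, callable) concrete, here is a minimal sketch of how an RMSE with configurable aggregation could behave. It is not pmrf's implementation: rmse_sketch is a hypothetical stand-in, the 'raw_values' option and the reading of an array as per-output weights are assumptions borrowed from scikit-learn's multioutput convention, and treating axis 0 as the sample axis is also assumed. It is written with NumPy to stay dependency-light; the same calls exist in jax.numpy.

```python
import numpy as np


def rmse_sketch(y_true, y_pred, multioutput="uniform_average"):
    """Hypothetical RMSE with multioutput aggregation (not pmrf's code)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)

    # Assumption: axis 0 indexes samples, so this yields one RMSE per output.
    per_output = np.sqrt(np.mean((y_true - y_pred) ** 2, axis=0))

    # Callable: the caller decides how to reduce the per-output values.
    if callable(multioutput):
        return multioutput(per_output)

    # String: named strategies ('raw_values' is an assumed option).
    if isinstance(multioutput, str):
        if multioutput == "uniform_average":
            return np.mean(per_output)
        if multioutput == "raw_values":
            return per_output
        raise ValueError(f"unknown multioutput option: {multioutput!r}")

    # Array: assumed to hold one weight per output dimension.
    return np.average(per_output, weights=np.asarray(multioutput, dtype=float))


y_true = np.zeros((2, 2))
y_pred = np.array([[3.0, 1.0], [-3.0, 1.0]])  # per-output RMSE: [3.0, 1.0]

uniform = rmse_sketch(y_true, y_pred)                                       # 2.0
weighted = rmse_sketch(y_true, y_pred, multioutput=np.array([0.75, 0.25]))  # 2.5
worst = rmse_sketch(y_true, y_pred, multioutput=np.max)                     # 3.0
```

With the class documented above, the equivalent call would presumably be RMSELoss(multioutput=...)(y_true, y_pred); check pmrf.math.losses.root_mean_squared_error() for the aggregation options it actually supports.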