aggregate

pmrf.math.aggregations.aggregate(mean_loss: Array, alias: str | Array | Callable = 'uniform_average') → Array

General aggregation of a loss across multiple output dimensions, controlled by an alias string, an array of per-output weights, or an input callable.

Parameters:
  • mean_loss (jnp.ndarray) – The loss array reduced over the batch dimension, shape (n_outputs,).

  • alias (str, jnp.ndarray, or Callable, default='uniform_average') – String alias (‘raw_values’, ‘uniform_average’, ‘geometric_mean’, ‘convolution’, or ‘log_mean’), an array of custom weights for each output, or a custom callable function.

Returns:

The fully aggregated loss as a scalar (or array if ‘raw_values’ is selected).

Return type:

jnp.ndarray
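To make the alias dispatch concrete, here is a minimal, hypothetical sketch of how some of the documented options could be implemented in JAX. It is not pmrf's code, and `aggregate_sketch` is an illustrative name. Several behaviours are assumptions: a weight array is treated as a weighted average normalised by the weight sum (the convention scikit-learn uses for multi-output weights), and 'geometric_mean' is taken as exp(mean(log(x))), which requires strictly positive losses. 'convolution' and 'log_mean' are left out because this entry does not define them.

```python
from typing import Callable, Union

import jax.numpy as jnp

Alias = Union[str, jnp.ndarray, Callable[[jnp.ndarray], jnp.ndarray]]


def aggregate_sketch(mean_loss: jnp.ndarray, alias: Alias = "uniform_average") -> jnp.ndarray:
    """Hypothetical re-implementation of a subset of the documented aliases.

    Assumptions (not taken from pmrf):
    - an array alias gives a weighted average normalised by the weight sum;
    - 'geometric_mean' is exp(mean(log(x))), so losses must be > 0;
    - 'convolution' and 'log_mean' are not sketched.
    """
    mean_loss = jnp.asarray(mean_loss)  # shape (n_outputs,)
    if callable(alias):
        # Custom callable: receives the per-output losses unchanged.
        return alias(mean_loss)
    if isinstance(alias, str):
        if alias == "raw_values":
            # No reduction: return one loss per output.
            return mean_loss
        if alias == "uniform_average":
            # Every output contributes equally.
            return jnp.mean(mean_loss)
        if alias == "geometric_mean":
            # Mean in log space, mapped back with exp.
            return jnp.exp(jnp.mean(jnp.log(mean_loss)))
        raise NotImplementedError(f"alias {alias!r} is not covered by this sketch")
    # Otherwise interpret alias as per-output weights (assumed normalised average).
    weights = jnp.asarray(alias)
    return jnp.sum(weights * mean_loss) / jnp.sum(weights)
```

For per-output losses `jnp.array([1.0, 4.0])`, the sketch returns 2.5 for 'uniform_average', about 2.0 for 'geometric_mean', 1.75 for weights `jnp.array([3.0, 1.0])`, and 4.0 when passed `jnp.max` as the callable. The callable branch is checked first so that any user function takes precedence over the built-in aliases.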