weighted_sum

pmrf.math.aggregations.weighted_sum(x: Array, weights: Array | None = None) → Array

Sums x over the sample (batch) dimension, i.e. the leading axis, weighting each sample by its entry in weights if weights are provided.

Parameters:
  • x (jnp.ndarray) – The raw input array of shape (n_samples, …).

  • weights (jnp.ndarray, optional) – Per-sample weights of shape (n_samples,). If None, no weighting is applied.

Returns:

A JAX array of shape (…), i.e. the trailing dimensions of x, containing the sample-reduced result (for example, a loss).

Return type:

jnp.ndarray
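The reduction described above can be sketched in plain JAX. The function below, weighted_sum_sketch, is a hypothetical stand-in and not pmrf's source; it assumes that each sample is multiplied by its weight before summing over the leading axis, and that weights=None means an unweighted sum. The main point it illustrates is how a weights vector of shape (n_samples,) must be reshaped to broadcast against x of shape (n_samples, …).

```python
from typing import Optional

import jax.numpy as jnp
from jax import Array


def weighted_sum_sketch(x: Array, weights: Optional[Array] = None) -> Array:
    """Hypothetical illustration of a weighted sum over the sample axis.

    Assumption: this mirrors the documented behaviour of
    pmrf.math.aggregations.weighted_sum, not its actual implementation.
    """
    if weights is None:
        # No weights: plain sum over the sample (leading) axis.
        return jnp.sum(x, axis=0)
    # Reshape weights from (n_samples,) to (n_samples, 1, ..., 1) so they
    # broadcast against x of shape (n_samples, ...).
    w = jnp.reshape(weights, weights.shape + (1,) * (x.ndim - 1))
    # Scale each sample by its weight, then sum over the sample axis.
    return jnp.sum(w * x, axis=0)


# Usage: 3 samples, each with 2 features.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
w = jnp.array([1.0, 0.0, 2.0])
print(weighted_sum_sketch(x, w))  # [11. 14.]
print(weighted_sum_sketch(x))     # [ 9. 12.]
```

Reshaping weights explicitly, rather than relying on NumPy-style broadcasting alone, matters because broadcasting aligns trailing axes: multiplying an (n_samples,) vector directly against an (n_samples, d) array would either fail or silently pair weights with features instead of samples.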