weighted_sum
- pmrf.math.aggregations.weighted_sum(x: Array, weights: Array | None = None) → Array
Reduces x over the sample (batch) dimension, i.e. the leading axis, applying per-sample weights if weights is provided.
- Parameters:
x (jnp.ndarray) – The raw input array of shape (n_samples, …).
weights (jnp.ndarray, optional) – Per-sample weights of shape (n_samples,). If None (the default), no weighting is applied.
- Returns:
A JAX array of shape x.shape[1:] (the trailing dimensions of x) containing the sample-reduced loss.
- Return type:
jnp.ndarray
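To illustrate the shapes involved, here is a minimal sketch of a weighted reduction over the leading axis. It assumes the reduction is a plain (unnormalized) sum, as the function name suggests; the helper name weighted_sum_sketch is hypothetical, and this is not the pmrf implementation, which may differ (for example, by normalizing the weights).

```python
import jax
import jax.numpy as jnp
from typing import Optional


def weighted_sum_sketch(x: jax.Array, weights: Optional[jax.Array] = None) -> jax.Array:
    """Hypothetical sketch: sum x over axis 0, optionally weighting each sample.

    Assumes an unnormalized weighted sum; not the actual pmrf code.
    """
    if weights is None:
        # No weights: plain sum over the sample (batch) axis.
        return jnp.sum(x, axis=0)
    # Reshape weights from (n_samples,) to (n_samples, 1, ..., 1) so they
    # broadcast against x of shape (n_samples, ...).
    w = jnp.reshape(weights, weights.shape + (1,) * (x.ndim - 1))
    return jnp.sum(w * x, axis=0)


# Usage: three samples, each a vector of length 2.
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
w = jnp.array([1.0, 0.0, 2.0])
unweighted = weighted_sum_sketch(x)   # shape (2,): [9., 12.]
weighted = weighted_sum_sketch(x, w)  # shape (2,): [11., 14.]
```

The reshape makes the per-sample weights broadcast over any number of trailing dimensions, so the output always has shape x.shape[1:], matching the documented return shape.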