GibbsMarginalLogLikelihood

class pmrf.evaluators.GibbsMarginalLogLikelihood(predictor: Callable[[Model, Frequency], jnp.ndarray], observed: Any, loss: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], temperature: float = 1.0, discrepancy: Callable[[jnp.ndarray, jnp.ndarray], dist.AbstractDistribution] | None = None, use_orthogonal_discrepancy: bool = False, event_transform: bij.AbstractBijector = None, event_ndims: int = 1)

Bases: AbstractEvaluator

Computes a generalized (Gibbs) log-likelihood from a loss function instead of a strict generative likelihood. Combined with a prior, it defines the generalized posterior known as the Gibbs measure.
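In standard generalized-Bayes notation, the Gibbs measure can be written as below. This assumes the loss is divided by the temperature, as the description of the temperature parameter implies. The symbols are our notation, not the library's: $\theta$ for the model parameters, $f_\theta(\omega)$ for the predictor output at frequencies $\omega$, $y$ for the observed data, $\ell$ for the loss and $T$ for the temperature.

```latex
\log \tilde{p}(y \mid \theta) = -\frac{1}{T}\,\ell\big(y, f_\theta(\omega)\big),
\qquad
p_T(\theta \mid y) \propto p(\theta)\,\exp\!\Big(-\frac{1}{T}\,\ell\big(y, f_\theta(\omega)\big)\Big)
```

When $T = 1$ and $\ell$ is the negative log-likelihood $-\log p(y \mid \theta)$, this reduces to the ordinary Bayesian posterior.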

This evaluator is intended for Generalized Bayesian Inference (GBI). It conditions on the prediction of a physical model and can optionally marginalize out a discrepancy model within an expected-loss framework.
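One common way to marginalize a discrepancy under an expected-loss framework is to average the loss over draws from the discrepancy distribution, $\mathbb{E}_{\delta \sim q}[\ell(y, f_\theta + \delta)]$. The sketch below estimates this expectation by Monte Carlo using a zero-mean Gaussian discrepancy. The additive form, the Gaussian choice and the names `expected_loss` and `sigma_delta` are assumptions for illustration. They may not match how pmrf actually combines the predictor and the discrepancy.

```python
import jax
import jax.numpy as jnp


def squared_error(y_true, y_pred):
    # Mean squared error; argument order (y_true, y_pred) as documented for `loss`.
    return jnp.mean((y_true - y_pred) ** 2)


def expected_loss(key, y_obs, y_pred, sigma_delta, num_samples=4096):
    """Monte Carlo estimate of E_{delta ~ N(0, sigma^2 I)}[loss(y_obs, y_pred + delta)].

    Hypothetical helper: assumes an additive, zero-mean Gaussian discrepancy.
    """
    # Draw discrepancy samples with the same shape as the prediction.
    deltas = sigma_delta * jax.random.normal(key, (num_samples,) + y_pred.shape)
    # Evaluate the loss for every discrepancy draw, then average.
    losses = jax.vmap(lambda d: squared_error(y_obs, y_pred + d))(deltas)
    return jnp.mean(losses)


key = jax.random.PRNGKey(0)
y_obs = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.1, 1.9, 3.2])
print(expected_loss(key, y_obs, y_pred, sigma_delta=0.5))
```

For squared error the expectation has the closed form "plain loss + σ²", which is a convenient check on the Monte Carlo estimate.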

Parameters:
  • predictor (Callable[[pmrf.models.base.Model, pmrf.frequency.Frequency], jax.Array]) – The predictor (e.g. another Evaluator) that extracts model features. Can be a function or a PyTree with optional parameters.

  • observed (jax.Array) – The observed data that the loss will be computed against. Must have a shape that matches the shape of the predictor output.

  • loss (Callable[[jax.Array, jax.Array], jax.Array]) – The loss function that takes (y_true, y_pred) and returns a loss metric. Can be a function or a PyTree with optional parameters.

  • temperature (float) – The temperature of the Gibbs measure. The loss is weighted by its inverse, 1 / temperature, so higher temperatures down-weight the data and produce wider, less confident posteriors.

  • discrepancy (Callable[[jax.Array, jax.Array], distreqx.distributions._distribution.AbstractDistribution] | None) – (experimental) An optional discrepancy model that accounts for model misspecification.

  • use_orthogonal_discrepancy (bool) – (experimental) Whether the discrepancy callable accepts a keyword argument “orthogonal_projection” that defines the model’s orthogonal subspace. Intended for Gaussian-process discrepancy models.

  • event_transform (distreqx.bijectors._bijector.AbstractBijector) – A bijective transform that maps from “observation space” (predicted features) to “event space”.

  • event_ndims (int) – The number of trailing event dimensions in event space to use as the event shape. Defaults to 1.
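In outline, the predictor maps a model and a frequency grid to features, the loss compares those features with `observed`, and the result is scaled by the temperature. The sketch below is a standalone stand-in that does not use pmrf. The names `toy_predictor` (a first-order low-pass magnitude response), `mse_loss` and `gibbs_log_likelihood` are hypothetical. Dividing the loss by the temperature is an assumption based on the `temperature` description.

```python
import jax.numpy as jnp


def toy_predictor(params, freqs):
    # Hypothetical predictor: |H(f)| of a first-order low-pass filter with cutoff fc.
    fc = params["fc"]
    return 1.0 / jnp.sqrt(1.0 + (freqs / fc) ** 2)


def mse_loss(y_true, y_pred):
    # Loss takes (y_true, y_pred), matching the documented argument order.
    return jnp.mean((y_true - y_pred) ** 2)


def gibbs_log_likelihood(predictor, observed, loss, params, freqs, temperature=1.0):
    y_pred = predictor(params, freqs)
    # The observed data must have the same shape as the predictor output.
    if y_pred.shape != observed.shape:
        raise ValueError(f"shape mismatch: {y_pred.shape} vs {observed.shape}")
    # Assumed Gibbs log-likelihood: negative loss scaled by 1 / temperature.
    return -loss(observed, y_pred) / temperature


freqs = jnp.linspace(1e6, 1e9, 51)
observed = toy_predictor({"fc": 2e8}, freqs)
candidate = {"fc": 1.5e8}
for T in (0.1, 1.0, 10.0):
    print(T, gibbs_log_likelihood(toy_predictor, observed, mse_loss, candidate, freqs, T))
```

A higher temperature shrinks the magnitude of the log-likelihood. The data then pull less strongly against the prior, and the posterior widens.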

__call__(model: Model, frequency: Frequency, **kwargs) → Array

Evaluate the model response over the specified frequency range.

Parameters:
  • model (Model) – The model instance to evaluate.

  • frequency (Frequency) – The frequency object defining the evaluation points.

  • **kwargs (dict) – Additional keyword arguments for the evaluation process.

Returns:

The evaluated model response.

Return type:

jnp.ndarray
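Because the return value is a JAX array, a scalar log-likelihood of this kind can be differentiated with `jax.grad`. Gradient-based samplers such as HMC/NUTS and gradient-based optimizers rely on this. The self-contained sketch below assumes `__call__` returns a scalar and does not call pmrf. The linear predictor and the name `log_lik` are placeholders.

```python
import jax
import jax.numpy as jnp

freqs = jnp.linspace(0.0, 1.0, 11)
observed = 2.0 * freqs  # synthetic data generated with slope 2.0


def log_lik(slope):
    # Hypothetical linear predictor with a squared-error loss.
    y_pred = slope * freqs
    loss = jnp.sum((observed - y_pred) ** 2)
    return -loss


grad_fn = jax.grad(log_lik)
print(grad_fn(1.5))  # positive: moving the slope towards 2.0 raises the log-likelihood
print(grad_fn(2.0))  # zero at the best-fitting slope
```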

discrepancy: Callable[[Array, Array], AbstractDistribution] | None = None

The active discrepancy model.

event_ndims: int = 1

The number of event dimensions.
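Distribution libraries that follow the TensorFlow Probability/distrax convention treat the trailing `event_ndims` axes as a single event. Per-element log-densities are summed over those axes, and the remaining leading axes form the batch shape. Assuming pmrf follows the same convention, the sketch below shows that reduction in plain JAX. The helper `reduce_event_dims` is hypothetical.

```python
import jax.numpy as jnp


def reduce_event_dims(elementwise_log_prob, event_ndims):
    """Sum element-wise log-densities over the trailing `event_ndims` axes."""
    if event_ndims == 0:
        return elementwise_log_prob
    axes = tuple(range(-event_ndims, 0))
    return jnp.sum(elementwise_log_prob, axis=axes)


# Example: 3 batch entries x 4 frequency points x 2 features per point.
lp = jnp.ones((3, 4, 2))
print(reduce_event_dims(lp, 1).shape)  # (3, 4)
print(reduce_event_dims(lp, 2).shape)  # (3,)
```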

event_transform: AbstractBijector = None

The bijective event transform.

loss: Callable[[Array, Array], Array]

The active loss function.

observed: Array

The observed data.

predictor: Callable[[Model, Frequency], Array]

The active predictor instance.

temperature: float = 1.0

The Gibbs measure temperature.

use_orthogonal_discrepancy: bool = False

Whether the discrepancy accepts an “orthogonal_projection” keyword argument.
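For Gaussian-process discrepancies, orthogonal GP calibration is one known construction. It restricts the discrepancy to the subspace orthogonal to the model's local sensitivity, meaning the column space of the Jacobian $J$ of the predictor with respect to the parameters, through the projector $P = I - J (J^\top J)^{-1} J^\top$. This page does not say whether pmrf's “orthogonal_projection” argument is this matrix. The sketch below only illustrates the construction, and the names `predictor` (a toy two-parameter line) and `orthogonal_complement_projector` are hypothetical.

```python
import jax
import jax.numpy as jnp


def predictor(theta, x):
    # Toy predictor with two parameters: offset and slope.
    return theta[0] + theta[1] * x


def orthogonal_complement_projector(jac):
    """Projector onto the orthogonal complement of the column space of `jac`."""
    n = jac.shape[0]
    # P = I - J (J^T J)^{-1} J^T, using a linear solve instead of an explicit inverse.
    return jnp.eye(n) - jac @ jnp.linalg.solve(jac.T @ jac, jac.T)


x = jnp.linspace(0.0, 1.0, 6)
theta = jnp.array([0.5, 2.0])
J = jax.jacfwd(predictor)(theta, x)  # Jacobian w.r.t. theta, shape (6, 2)
P = orthogonal_complement_projector(J)
# Near zero: projected discrepancy directions cannot mimic a parameter change.
print(jnp.max(jnp.abs(P @ J)))
```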