GibbsMarginalLogLikelihood
- class pmrf.evaluators.GibbsMarginalLogLikelihood(predictor: Callable[[Model, Frequency], jnp.ndarray], observed: Any, loss: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], temperature: float = 1.0, discrepancy: Callable[[jnp.ndarray, jnp.ndarray], dist.AbstractDistribution] | None = None, use_orthogonal_discrepancy: bool = False, event_transform: bij.AbstractBijector = None, event_ndims: int = 1)
Bases: AbstractEvaluator
Computes a generalized log-posterior (the Gibbs measure) using a loss function instead of a strict generative likelihood.
This is used for Generalized Bayesian Inference (GBI). It conditions a physical model's prediction on the observed data while optionally marginalizing out a discrepancy model within an expected-loss framework.
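For reference, the standard Gibbs-posterior form used in GBI is shown below, with the temperature $T$ acting as an inverse weight on the loss $\ell$. The symbols $\pi$ (prior), $f_\theta$ (predictor output for parameters $\theta$), $y$ (observed data), $q$ (discrepancy distribution) and the additive discrepancy $\delta$ are notation introduced here. The second line is an assumed reading of the expected-loss treatment of the discrepancy, not the class's documented formula.

```latex
\pi_T(\theta \mid y) \;\propto\; \pi(\theta)\,
  \exp\!\left(-\frac{1}{T}\,\ell\big(y,\, f_\theta\big)\right)

% Assumed expected-loss form with an additive discrepancy \delta \sim q:
\pi_T(\theta \mid y) \;\propto\; \pi(\theta)\,
  \exp\!\left(-\frac{1}{T}\,\mathbb{E}_{\delta \sim q}\!\left[\ell\big(y,\, f_\theta + \delta\big)\right]\right)
```

Raising $T$ shrinks the exponent toward zero, which is why higher temperatures give wider posteriors.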
- Parameters:
predictor (Callable[[pmrf.models.base.Model, pmrf.frequency.Frequency], jax.Array]) – The predictor (e.g., another Evaluator) that extracts model features. Can be a function or a PyTree with optional parameters.
observed (jax.Array) – The observed data that the loss is computed against. Its shape must match the shape of the predictor output.
loss (Callable[[jax.Array, jax.Array], jax.Array]) – The loss function that takes (y_true, y_pred) and returns a loss value. Can be a function or a PyTree with optional parameters.
temperature (float) – The temperature of the Gibbs measure; the loss is weighted by its inverse. Higher temperatures therefore produce wider, less confident posteriors.
discrepancy (Callable[[jax.Array, jax.Array], distreqx.distributions._distribution.AbstractDistribution] | None) – (experimental) An optional discrepancy model that accounts for model misspecification.
use_orthogonal_discrepancy (bool) – (experimental) Whether the discrepancy callable accepts a keyword argument “orthogonal_projection” that defines the model’s orthogonal subspace. Used with Gaussian-process discrepancy models.
event_transform (distreqx.bijectors._bijector.AbstractBijector) – A bijective transform that maps from “observation space” (predicted features) to “event space”.
event_ndims (int) – The number of trailing event dimensions in event space to use as the event shape. Defaults to 1.
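To illustrate how event_ndims partitions an array in event space, here is a small sketch of the usual batch-shape / event-shape split: the last event_ndims axes form the event shape and the remaining leading axes form the batch shape. The helper split_event_shape is hypothetical (not part of pmrf), and the (101, 2, 2) shape is purely illustrative.

```python
# Hypothetical helper illustrating how `event_ndims` splits a shape;
# it is not part of the pmrf API.


def split_event_shape(
    shape: tuple[int, ...], event_ndims: int = 1
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return (batch_shape, event_shape), taking the last `event_ndims` axes as the event."""
    if not 0 <= event_ndims <= len(shape):
        raise ValueError("event_ndims must lie between 0 and the array rank")
    cut = len(shape) - event_ndims
    return shape[:cut], shape[cut:]


# Illustrative shape: 101 points, each a 2x2 matrix of features.
shape = (101, 2, 2)
print(split_event_shape(shape, event_ndims=1))  # ((101, 2), (2,))
print(split_event_shape(shape, event_ndims=2))  # ((101,), (2, 2))
```

With the default event_ndims=1, only the last axis is treated as one joint event; increasing it groups more trailing axes into a single event.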
- __call__(model: Model, frequency: Frequency, **kwargs) → Array
Evaluate the model response over the specified frequency range.
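A minimal sketch of what a call reduces to when no discrepancy model is set, assuming the loss enters the Gibbs measure as -loss(observed, prediction) / temperature (consistent with the temperature being described as an inverse weight). gibbs_log_likelihood, mse and toy_predictor are hypothetical stand-ins rather than pmrf functions, the dict passed as model replaces a real pmrf Model, and event_transform, event_ndims and the discrepancy path are left out.

```python
import jax.numpy as jnp


def gibbs_log_likelihood(predictor, observed, loss, model, frequency, temperature=1.0):
    """Hypothetical no-discrepancy case: -loss(observed, prediction) / temperature."""
    prediction = predictor(model, frequency)
    return -loss(observed, prediction) / temperature


def mse(y_true, y_pred):
    # Mean squared error; jnp.abs keeps the result real for complex-valued data.
    return jnp.mean(jnp.abs(y_true - y_pred) ** 2)


def toy_predictor(model, frequency):
    # Stand-in "model": a single gain applied to the frequency grid.
    return model["gain"] * frequency


frequency = jnp.linspace(1.0, 2.0, 5)
observed = 2.0 * frequency          # data generated with gain = 2.0
model = {"gain": 1.5}               # deliberately mismatched parameter

ll_cold = gibbs_log_likelihood(toy_predictor, observed, mse, model, frequency, temperature=1.0)
ll_hot = gibbs_log_likelihood(toy_predictor, observed, mse, model, frequency, temperature=2.0)
print(float(ll_cold), float(ll_hot))  # -0.59375 -0.296875
```

Doubling the temperature halves the magnitude of the log-likelihood, which flattens the resulting posterior over the model parameters.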
- discrepancy: Callable[[Array, Array], AbstractDistribution] | None = None
The active discrepancy model.
- event_ndims: int = 1
The number of event dimensions.
- event_transform: AbstractBijector = None
The bijective event transform.
- loss: Callable[[Array, Array], Array]
The active loss function.
- observed: Array
The observed data.
- temperature: float = 1.0
The Gibbs measure temperature.
- use_orthogonal_discrepancy: bool = False
Flag for orthogonal discrepancy.