MarginalLogLikelihood

class pmrf.evaluators.MarginalLogLikelihood(predictor: Callable[[Model, Frequency], jnp.ndarray], observed: Any, likelihood: Callable[[jnp.ndarray | dist.AbstractDistribution], dist.AbstractDistribution], discrepancy: Callable[[jnp.ndarray, jnp.ndarray], dist.AbstractDistribution] | None = None, use_orthogonal_discrepancy: bool = False, event_transform: bij.AbstractBijector = None, event_ndims: int = 1)

Bases: AbstractEvaluator

Computes the log-probability of the observed data under a likelihood conditioned on a model prediction, marginalizing out an optional discrepancy model.
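To illustrate what marginalizing out a discrepancy means, the sketch below assumes the simplest conjugate case: a Gaussian likelihood with noise variance s² and a zero-mean Gaussian discrepancy with covariance K added to the prediction. Integrating the discrepancy out then gives a Gaussian whose covariance is K + s²I. The helper names and the squared-exponential kernel are hypothetical, and plain NumPy is used. This is not pmrf's implementation, and the library's likelihoods and discrepancy models may take other forms.

```python
import numpy as np


def squared_exponential(x, variance, lengthscale):
    """Hypothetical discrepancy kernel: k(x, x') = variance * exp(-(x - x')^2 / (2 l^2))."""
    d = x[:, None] - x[None, :]
    return variance * np.exp(-0.5 * (d / lengthscale) ** 2)


def gaussian_marginal_log_likelihood(y, prediction, k_discrepancy, noise_var):
    """log N(y | prediction, K + noise_var * I).

    Assumed model: y = prediction + delta + eps, with
    delta ~ N(0, K) (discrepancy) and eps ~ N(0, noise_var * I) (likelihood noise).
    Integrating delta out leaves y ~ N(prediction, K + noise_var * I).
    """
    residual = np.asarray(y, dtype=float) - np.asarray(prediction, dtype=float)
    n = residual.shape[-1]
    cov = k_discrepancy + noise_var * np.eye(n)
    chol = np.linalg.cholesky(cov)            # cov = L L^T
    alpha = np.linalg.solve(chol, residual)   # alpha = L^{-1} r, so alpha.alpha = r^T cov^{-1} r
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (alpha @ alpha + log_det + n * np.log(2.0 * np.pi))


# Usage: five frequency points (in GHz), a zero prediction and a small offset in the data.
x = np.linspace(1.0, 2.0, 5)
y = np.full(5, 0.1)
ll = gaussian_marginal_log_likelihood(y, np.zeros(5), squared_exponential(x, 0.01, 0.5), 1e-4)
```

With K = 0 this reduces to an ordinary independent Gaussian likelihood; a non-zero K widens the predictive distribution where the model is allowed to be wrong.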

Includes a mapping from the model’s “observation space” to “event space”. By default, the frequency axis is treated as the probabilistic event: it is moved to the last axis before the data is passed to the likelihood and discrepancy. Alternatively, a bijective transform can be supplied to define the probability model in an arbitrary latent space.
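The default observation-to-event mapping can be pictured with the NumPy sketch below. It assumes the predictor returns a complex array with frequency on axis 0 (for example S-parameters of shape (nfreq, nports, nports)) and that the real and imaginary parts are stacked along a new leading axis. Both the layout and the helper names are assumptions for illustration, not pmrf’s actual code.

```python
import numpy as np


def to_event_space(observation, freq_axis=0):
    """Hypothetical default mapping: complex observation -> real event array.

    The frequency axis is moved last so it becomes the event axis, and the
    real and imaginary parts are stacked on a new leading axis so they can be
    modelled as independent.
    """
    obs = np.moveaxis(np.asarray(observation), freq_axis, -1)  # (..., nfreq)
    return np.stack([obs.real, obs.imag], axis=0)               # (2, ..., nfreq)


def from_event_space(event, freq_axis=0):
    """Inverse of to_event_space: recombine real/imag and restore the frequency axis."""
    obs = event[0] + 1j * event[1]                               # (..., nfreq)
    return np.moveaxis(obs, -1, freq_axis)


# Usage: 101 frequency points of a 2-port S-matrix.
rng = np.random.default_rng(0)
s = rng.standard_normal((101, 2, 2)) + 1j * rng.standard_normal((101, 2, 2))
event = to_event_space(s)
```

Because the mapping is invertible, no information is lost; a custom event_transform would play the same role for other latent spaces.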

Parameters:
  • predictor (Callable[[pmrf.models.base.Model, pmrf.frequency.Frequency], jax.Array]) – The predictor (e.g. another Evaluator) that extracts model features. Can be a function or a PyTree with optional parameters.

  • observed (numpy.ndarray) – The observed data whose log-probability is computed. Its shape must match the shape of the predictor output.

  • likelihood (Callable[[jax.Array | distreqx.distributions._distribution.AbstractDistribution], distreqx.distributions._distribution.AbstractDistribution]) – The likelihood function. It takes the model prediction (an array or a distribution) and returns the distribution over the observed data. Can be a function or a PyTree with optional parameters. See pmrf.likelihoods for common likelihoods.

  • discrepancy (Callable[[jax.Array, jax.Array], distreqx.distributions._distribution.AbstractDistribution] | None) – An optional discrepancy model that accounts for model misspecification. Can be a function or a PyTree with optional parameters. See pmrf.discrepancy_models for common discrepancy models.

  • use_orthogonal_discrepancy (bool) – Whether the discrepancy callable accepts a keyword argument “orthogonal_projection” that defines the model’s orthogonal subspace. Used for Gaussian processes.

  • event_transform (distreqx.bijectors._bijector.AbstractBijector) – A bijective transform that maps from “observation space” (predicted features) to “event space” (where probabilities are defined). Pass None to use the default mapping (frequency as the event axis, with real and imaginary parts treated independently).

  • event_ndims (int) – The number of trailing event dimensions in event space to use as the event shape. Defaults to 1.
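To make event_ndims concrete: with elementwise independent densities, choosing k trailing event dimensions means log-probabilities are summed over those k axes, and the remaining leading axes act as batch dimensions. The NumPy sketch below assumes an independent Gaussian per element and uses hypothetical helper names; how pmrf reduces the remaining batch dimensions to a scalar is not described in this reference.

```python
import numpy as np


def normal_log_prob(x, loc, scale):
    """Elementwise log density of N(loc, scale^2)."""
    z = (x - loc) / scale
    return -0.5 * z**2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)


def log_prob_with_event_ndims(x, loc, scale, event_ndims=1):
    """Sum elementwise log densities over the trailing `event_ndims` axes.

    The trailing `event_ndims` axes are removed; the leading (batch) axes remain.
    """
    lp = normal_log_prob(x, loc, scale)
    if event_ndims > 0:
        lp = lp.sum(axis=tuple(range(-event_ndims, 0)))
    return lp


# Usage: an event-space array of shape (2, 2, 2, 101), e.g. real/imag of a 2-port.
x = np.zeros((2, 2, 2, 101))
lp_k1 = log_prob_with_event_ndims(x, 0.0, 1.0, event_ndims=1)  # shape (2, 2, 2)
lp_k2 = log_prob_with_event_ndims(x, 0.0, 1.0, event_ndims=2)  # shape (2, 2)
```

The total log-probability is the same either way; event_ndims only changes which axes are grouped into a single event, which matters once the density is no longer elementwise independent (for example a correlated discrepancy across frequency).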

__call__(model: Model, frequency: Frequency, **kwargs) → Array

Evaluate the marginal log-likelihood of the observed data for the given model at the specified frequencies.

Parameters:
  • model (Model) – The model instance to evaluate.

  • frequency (Frequency) – The frequency object defining the evaluation points.

  • **kwargs (dict) – Additional keyword arguments for the evaluation process.

Returns:

The marginal log-likelihood of the observed data.

Return type:

jnp.ndarray

predictive_distribution(model: Model, frequency: Frequency, **kwargs) → AbstractDistribution

Returns the full predictive distribution of an observed event for a given model.

The returned distribution is in event space. To draw a sample from this distribution in observation space, see MarginalLogLikelihood.sample_observation().

sample_observation(key: Array, model: Model, frequency: Frequency, **kwargs) → Array

Returns a sample from the predictive distribution in observation space.
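The relationship between predictive_distribution() and sample_observation() can be sketched as two steps: draw a sample in event space, then push it back through the inverse of the event mapping. The NumPy sketch below assumes independent Gaussians in event space and an illustrative layout with real/imag stacked on axis 0 and frequency last. The function name and signature are hypothetical, and a NumPy Generator stands in for the JAX key used by the real method.

```python
import numpy as np


def sample_observation_sketch(rng, mean_event, scale_event, freq_axis=0):
    """Hypothetical: sample in event space, then map back to observation space.

    mean_event and scale_event broadcast to shape (2, ..., nfreq): real/imag
    stacked on axis 0 and frequency on the last axis.
    """
    # 1. Draw one sample from independent Gaussians in event space.
    event = mean_event + scale_event * rng.standard_normal(np.shape(mean_event))
    # 2. Invert the event mapping: recombine real/imag, move frequency back.
    obs = event[0] + 1j * event[1]
    return np.moveaxis(obs, -1, freq_axis)


# Usage: 2-port, 51 frequency points, noise-free (scale 0) so the sample equals the mean.
rng = np.random.default_rng(1)
mean_event = np.ones((2, 2, 2, 51))
sample = sample_observation_sketch(rng, mean_event, 0.0)
```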

discrepancy: Callable[[Array, Array], AbstractDistribution] | None = None

The optional discrepancy model.

event_ndims: int = 1

The number of trailing event dimensions.

event_transform: AbstractBijector = None

The bijective event transform.

likelihood: Callable[[Array | AbstractDistribution], AbstractDistribution]

The active likelihood function.

observed: ndarray

The observed data.

predictor: Callable[[Model, Frequency], Array]

The active predictor instance.

use_orthogonal_discrepancy: bool = False

Whether the discrepancy model is passed the “orthogonal_projection” keyword argument.
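For context on what an orthogonal subspace can mean here: one common construction, used in some orthogonal Gaussian-process discrepancy models, takes the Jacobian J of the prediction with respect to the model parameters and projects the discrepancy onto the orthogonal complement of J’s column space, P = I - J (JᵀJ)⁻¹ Jᵀ, so the discrepancy cannot mimic a small parameter change. The NumPy sketch below assumes that construction; this reference does not state how pmrf builds orthogonal_projection, and the helper names are hypothetical.

```python
import numpy as np


def orthogonal_projection(jacobian):
    """P = I - J (J^T J)^{-1} J^T, assuming J (n_points x n_params) has full column rank.

    In that case np.linalg.pinv(J) equals (J^T J)^{-1} J^T, so (J^T J)^{-1}
    never has to be formed explicitly.
    """
    J = np.asarray(jacobian, dtype=float)
    return np.eye(J.shape[0]) - J @ np.linalg.pinv(J)


def projected_discrepancy_cov(k, jacobian):
    """Covariance of P @ delta when delta ~ N(0, K): P K P^T."""
    P = orthogonal_projection(jacobian)
    return P @ k @ P.T


# Usage: 6 frequency points, 2 model parameters.
rng = np.random.default_rng(2)
J = rng.standard_normal((6, 2))
P = orthogonal_projection(J)
K_orth = projected_discrepancy_cov(np.eye(6), J)
```

P is a symmetric, idempotent projector that annihilates J, so any discrepancy drawn with covariance P K Pᵀ is orthogonal to the directions the parameters can already explain.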