"""
Conditioning a model on data using Bayesian inference.
"""
from typing import Callable
from functools import partial
import jax.numpy as jnp
import skrf
import distreqx.distributions as dist
import parax as prx
from inferix import PolyChord
from pmrf.core import Model, Frequency
from pmrf.math import CONVERSION_LOOKUP, LOSS_LOOKUP
from pmrf.constants import Inferer
from pmrf.network_collection import NetworkCollection
from pmrf.models import Measured
from pmrf.evaluators import FeatureAlias, DataLikelihood
from pmrf.likelihoods import ComplexGaussianLikelihood
from pmrf.infer.result import InferResult
from pmrf.infer.sample import sample
[docs]
def condition(
    model: Model,
    data: jnp.ndarray | skrf.Network | NetworkCollection,
    frequency: Frequency | None = None,
    solver: Inferer = PolyChord(),
    *,
    features: str | list[str] | Callable = 's',
    likelihood_fn: Callable[[jnp.ndarray], dist.AbstractDistribution] | None = None,
    **kwargs,
) -> InferResult:
    """
    Condition an RF model on measured data using Bayesian inference.

    Network inputs are converted into a target array by evaluating ``features``
    on a :class:`Measured` wrapper of the data; any other input is used as the
    target unchanged. The target, the feature evaluator and the likelihood are
    combined into a :class:`DataLikelihood`, which is then handed to
    :func:`sample`.

    Parameters
    ----------
    model : Model
        The RF model to fit.
    data : jnp.ndarray | skrf.Network | NetworkCollection
        The target data to fit against. A ``skrf.Network`` or
        ``NetworkCollection`` is passed through ``features`` to build the
        target; anything else is used directly as the target.
    frequency : Frequency | None, default=None
        The frequency sweep. Required unless ``data`` is a ``skrf.Network`` or
        ``NetworkCollection``, in which case it is derived from
        ``data.frequency`` or ``data.common_frequency()`` respectively when
        not given.
    solver : Inferer, default=PolyChord()
        The sampling backend forwarded to :func:`sample`. The default instance
        is created once, when this function is defined.
    features : str | list[str] | Callable, default='s'
        The feature(s) the likelihood is computed against. A non-callable
        value is wrapped in :class:`FeatureAlias`; a callable is used as-is and
        is called as ``features(model, frequency)``.
    likelihood_fn : Callable[[jnp.ndarray], dist.AbstractDistribution], optional
        Maps a model prediction to the distribution of the observed data given
        that prediction. See :mod:`pmrf.likelihoods`. When ``None``, a
        :class:`pmrf.likelihoods.ComplexGaussianLikelihood` is constructed with
        ``sigma=prx.Uniform(0.0, 100.0, scale=1e-3)``.
    **kwargs : dict
        Additional keyword arguments forwarded to :func:`sample`.

    Returns
    -------
    InferResult
        The result returned by :func:`sample`.

    Raises
    ------
    ValueError
        If ``frequency`` is ``None`` and ``data`` is not a ``skrf.Network`` or
        ``NetworkCollection``, since no sweep can be derived from the data.
    """
    is_network_data = isinstance(data, (skrf.Network, NetworkCollection))

    # A frequency sweep can only be inferred from network-like data. Checking
    # for "not a network" rather than "is a jnp.ndarray" also rejects NumPy
    # arrays and other array-likes instead of forwarding frequency=None.
    if frequency is None and not is_network_data:
        raise ValueError("Frequency must be passed if Network data is not provided")

    # Resolve the feature evaluator: aliases such as 's' are wrapped, while
    # user-supplied callables are used unchanged.
    if not callable(features):
        features = FeatureAlias(features)

    # Resolve the frequency and the target data
    if is_network_data:
        if frequency is None:
            if isinstance(data, skrf.Network):
                skrf_frequency = data.frequency
            else:
                skrf_frequency = data.common_frequency()
            frequency = Frequency.from_skrf(skrf_frequency)
        # Evaluate the measurement with the same evaluator as the model so the
        # target and the predictions are directly comparable.
        target = features(Measured(data), frequency)
    else:
        target = data

    # Resolve the likelihood model
    if likelihood_fn is None:
        likelihood_fn = ComplexGaussianLikelihood(sigma=prx.Uniform(0.0, 100.0, scale=1e-3))

    log_likelihood_fn = DataLikelihood(predictor=features, data=target, likelihood=likelihood_fn)

    # Run the sampling
    return sample(log_likelihood_fn, model, frequency, solver=solver, **kwargs)