Source code for pmrf.infer.result

from typing import Callable

import parax as prx
import jax
import jax.numpy as jnp
import inferix as infx
from parax import partition
from jaxtyping import Array

from pmrf.core import Model, Frequency, Problem

class InferResult(prx.Module):
    """
    Standardized return object for inference routines.

    Attributes
    ----------
    model : Model
        The circuit model holding the finalized, optimized parameter state and posterior.
    likelihood : Callable
        The likelihood evaluator (e.g., :class:`pmrf.evaluators.Likelihood`) used to
        calculate the likelihood. It must provide ``flat_param_names()`` for export.
    sampled_models : Model
        The final batched model of sampled models.
    sampled_likelihoods : Array
        A batched likelihood evaluator containing all accepted sample states.
    log_likelihoods : jnp.ndarray
        The evaluated log-likelihoods for each sample.
    weights : jnp.ndarray or None
        Optional per-sample weights; exported alongside the samples when present.
    frequency : Frequency or None
        Optional :class:`Frequency` associated with this result.
    stats : infx.Result or None
        The underlying result/trace object returned by the solver.
    """

    model: Model  # The model updated with empirical posterior distributions
    likelihood: Callable  # The likelihood evaluator used

    # Raw sample data
    sampled_models: Model  # A batched Model containing all accepted sample states
    # NOTE(review): annotated as Array, but it is partitioned and flattened like a
    # batched likelihood module below — confirm the intended annotation.
    sampled_likelihoods: Array  # A batched Likelihood containing all accepted sample states
    log_likelihoods: jnp.ndarray  # The evaluated log-likelihoods for each sample

    weights: jnp.ndarray | None = None
    frequency: Frequency | None = None
    stats: infx.Result | None = None  # Results/trace from the underlying nested sampler

    def _prepare_export_data(self, model_prefix: str, likelihood_prefix: str):
        """Extract, name and collision-check the flattened sample parameters for export.

        Parameters
        ----------
        model_prefix : str
            Prefix joined with ``_`` onto every model parameter name. An empty
            string means no prefix.
        likelihood_prefix : str
            Prefix joined with ``_`` onto every likelihood parameter name. An empty
            string means no prefix.

        Returns
        -------
        param_names : list of str
            Model parameter names followed by likelihood parameter names.
        sampled_params : numpy.ndarray
            One row per sample: the flattened model parameters followed by the
            flattened likelihood parameters.

        Raises
        ------
        ValueError
            If two exported parameter names coincide after prefixing.
        """
        import collections

        import numpy as np

        # `jax.flatten_util` is a submodule that a bare `import jax` is not
        # guaranteed to load, so import the function explicitly.
        from jax.flatten_util import ravel_pytree

        # 1. Cleanly format prefixes
        m_prefix = f"{model_prefix}_" if model_prefix else ""
        l_prefix = f"{likelihood_prefix}_" if likelihood_prefix else ""
        model_param_names = [f"{m_prefix}{name}" for name in self.model.flat_param_names()]
        likelihood_param_names = [
            f"{l_prefix}{name}" for name in self.likelihood.flat_param_names()
        ]

        # 2. Perform collision check
        param_names = model_param_names + likelihood_param_names
        if len(param_names) != len(set(param_names)):
            duplicates = [
                item for item, count in collections.Counter(param_names).items() if count > 1
            ]
            raise ValueError(
                f"Parameter name collision detected for: {duplicates}. \n"
                "Please provide a unique `model_prefix` and/or `likelihood_prefix` to resolve this."
            )

        # 3. Partition out static variables to isolate batched dynamic parameters
        dynamic_models, _ = partition(self.sampled_models)
        dynamic_likelihoods, _ = partition(self.sampled_likelihoods)

        # 4. Flatten each sample's dynamic parameters into a single row
        def flatten_batch(tree):
            """Flatten a batched pytree into one row of parameter values per sample."""
            if not jax.tree_util.tree_leaves(tree):
                # `jax.vmap` rejects inputs with no array leaves (e.g. a likelihood
                # without free parameters); contribute zero columns instead.
                n_samples = np.shape(self.log_likelihoods)[0]
                return jnp.zeros((n_samples, 0))
            return jax.vmap(lambda t: ravel_pytree(t)[0])(tree)

        sampled_model_params = flatten_batch(dynamic_models)
        sampled_likelihood_params = flatten_batch(dynamic_likelihoods)

        # 5. Concatenate and cast to standard numpy
        sampled_params = np.asarray(
            jnp.hstack((sampled_model_params, sampled_likelihood_params))
        )
        return param_names, sampled_params

    def combined_flat_param_values(self) -> jnp.ndarray:
        """Return the sampled model and likelihood parameters as one 2-D array.

        Names are internally prefixed with ``model_`` / ``likelihood_`` so that
        model and likelihood parameter names cannot collide with each other; only
        the values are returned.

        Returns
        -------
        numpy.ndarray
            One row per sample with model parameters first, followed by likelihood
            parameters. (The export helper converts with ``np.asarray``, so this is
            a NumPy array despite the annotation.)
        """
        return self._prepare_export_data(model_prefix="model", likelihood_prefix="likelihood")[1]

    def to_arviz(self, model_prefix="", likelihood_prefix=""):
        """Convert the samples to an ArviZ ``InferenceData`` object.

        Parameters
        ----------
        model_prefix : str, optional
            Prefix for model parameter names (joined with ``_``); empty for none.
        likelihood_prefix : str, optional
            Prefix for likelihood parameter names (joined with ``_``); empty for none.

        Returns
        -------
        arviz.InferenceData
            A single-chain object whose posterior holds one variable per parameter
            and whose sample stats hold ``log_likelihood`` (and ``weights`` when set).
        """
        import arviz as az
        import numpy as np

        # 1. Get standardized names and numpy arrays
        param_names, sampled_params = self._prepare_export_data(model_prefix, likelihood_prefix)

        # 2. ArviZ expects (n_chains, n_draws) arrays; add a single dummy chain axis.
        posterior_dict = {
            name: np.expand_dims(sampled_params[:, i], axis=0)
            for i, name in enumerate(param_names)
        }

        # 3. Extract sample statistics
        sample_stats = {
            "log_likelihood": np.expand_dims(np.asarray(self.log_likelihoods), axis=0)
        }
        if self.weights is not None:
            sample_stats["weights"] = np.expand_dims(np.asarray(self.weights), axis=0)

        # 4. Build and return the InferenceData object
        return az.from_dict(posterior=posterior_dict, sample_stats=sample_stats)

    def to_anesthetic(self, model_prefix="", likelihood_prefix="", logL_birth=None):
        """Convert the samples to an anesthetic samples object.

        Parameters
        ----------
        model_prefix : str, optional
            Prefix for model parameter names (joined with ``_``); empty for none.
        likelihood_prefix : str, optional
            Prefix for likelihood parameter names (joined with ``_``); empty for none.
        logL_birth : array-like or int, optional
            Birth log-likelihoods of each sample, or an integer number of live
            points. When given, ``anesthetic.NestedSamples`` is returned;
            otherwise ``anesthetic.Samples``.

        Returns
        -------
        anesthetic.Samples or anesthetic.NestedSamples
        """
        import anesthetic as an
        import numpy as np
        import pandas as pd

        # 1. Get standardized names and numpy arrays
        param_names, sampled_params = self._prepare_export_data(model_prefix, likelihood_prefix)

        # 2. Build the core pandas DataFrame
        df = pd.DataFrame(sampled_params, columns=param_names)

        # 3. Extract sample statistics
        logL = np.asarray(self.log_likelihoods)
        weights = np.asarray(self.weights) if self.weights is not None else None

        # 4. Determine which anesthetic object to build
        if logL_birth is None:
            return an.Samples(data=df, logL=logL, weights=weights)

        # anesthetic also accepts an integer number of live points for `logL_birth`
        # and detects that case by type, so integers must not be wrapped in an array.
        if isinstance(logL_birth, (int, np.integer)):
            logL_birth = int(logL_birth)
        else:
            logL_birth = np.asarray(logL_birth)
        return an.NestedSamples(data=df, logL=logL, logL_birth=logL_birth, weights=weights)