Source code for pmrf.explore.sample

from typing import Callable

import jax
import jax.numpy as jnp
import jax.random as jr

from pmrf.core import Model, Frequency, EvaluatorLike
from pmrf.evaluators import FeatureAlias
from pmrf.explore.samplers import AbstractSampler
from pmrf.explore.result import ExploreResult


def sample(
    model: Model,
    sampler: AbstractSampler,
    *,
    max_samples: int | None = None,
    frequency: Frequency | None = None,
    features: EvaluatorLike | None = 's',
    key: jax.Array | None = None,
    **kwargs,
) -> ExploreResult:
    """
    Explore the parameter space of a model using a specified sampling engine.

    This unified router executes the sampling algorithm using a standardized
    state-machine loop (init, step, terminate), supporting both one-shot and
    adaptive active-learning strategies.

    Parameters
    ----------
    model : Model
        The parametric model to sample.
    sampler : AbstractSampler
        The sampling algorithm to use.
    max_samples : int | None, default=None
        The maximum number of samples to generate. For one-shot samplers,
        this is the exact number generated. For adaptive samplers, it acts
        as a computational budget. If None, adaptive samplers run until
        convergence.
    frequency : Frequency | None, default=None
        The frequency sweep for feature evaluation.
    features : EvaluatorLike | None, default='s'
        The specific circuit features to extract.
    key : jax.Array | None, default=None
        JAX PRNG key for stochastic samplers.
    **kwargs
        Additional arguments passed to the underlying evaluators.

    Returns
    -------
    ExploreResult
        The comprehensive result object containing the original continuous
        model and batched execution states.
    """
    if key is None:
        key = jr.PRNGKey(0)
    if not callable(features):
        features = FeatureAlias(features)

    d = model.num_flat_params

    def eval_fn(U_batch: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Closure mapping hypercube proposals to physical params and features."""
        return _evaluate_batch(model, U_batch, frequency, features, **kwargs)

    # 1. Initialize the solver state
    options = {"max_samples": max_samples}
    key, init_key = jr.split(key)
    state = sampler.init(eval_fn, d, init_key, options)

    # 2. Standardized execution loop
    while not sampler.terminate(state, max_samples):
        key, step_key = jr.split(key)
        state = sampler.step(eval_fn, d, state, step_key, options)

    # 3. Truncate if batching pushed us slightly over the budget
    thetas = state.sampled_params
    extracted_features = state.sampled_features
    if max_samples is not None and len(thetas) > max_samples:
        thetas = thetas[:max_samples]
        extracted_features = extracted_features[:max_samples]

    # 4. Package the array of parameters back into a cleanly batched JAX PyTree
    batched_models = jax.vmap(model.with_params)(thetas)

    return ExploreResult(
        model=model,  # Leave the original continuous model untouched
        frequency=frequency,
        sampled_models=batched_models,
        sampled_features=extracted_features,
        history=state.backend_state,
    )
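For reference, a minimal sketch of a one-shot sampler satisfying the init/step/terminate protocol assumed by the loop above. The names `UniformSampler` and `_UniformState` are hypothetical illustrations, not part of pmrf; a real implementation would subclass `AbstractSampler` and carry its own backend state.

# Illustrative only: hypothetical one-shot sampler, not part of pmrf.
from dataclasses import dataclass, field

@dataclass
class _UniformState:
    sampled_params: jnp.ndarray | None = None
    sampled_features: jnp.ndarray | None = None
    backend_state: dict = field(default_factory=dict)

class UniformSampler:
    """Draws max_samples points uniformly from the unit hypercube in one step."""

    def init(self, eval_fn, d, key, options):
        return _UniformState()

    def terminate(self, state, max_samples):
        # One-shot: stop as soon as a single batch has been evaluated.
        return state.sampled_params is not None

    def step(self, eval_fn, d, state, key, options):
        n = options["max_samples"] or 64  # assumed fallback batch size
        U = jr.uniform(key, (n, d))       # proposals in the unit hypercube
        thetas, feats = eval_fn(U)        # mapped/evaluated via _evaluate_batch
        return _UniformState(sampled_params=thetas, sampled_features=feats)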
def _evaluate_batch(
    model: Model,
    U: jnp.ndarray,
    frequency: Frequency | None,
    features: Callable,
    **kwargs,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Map unit-hypercube proposals to physical parameters and evaluate features."""

    def eval_single(u):
        # Transform each coordinate from U(0, 1) to its physical distribution
        # via the inverse CDF (probability integral transform).
        flat_params = model.flat_params()
        theta = jnp.array(
            [p.distribution.icdf(u_i) for p, u_i in zip(flat_params, u)]
        )
        m_sampled = model.with_params(theta)
        feat_val = (
            features(m_sampled, frequency, **kwargs)
            if frequency is not None
            else features(m_sampled, **kwargs)
        )
        return theta, feat_val

    return jax.vmap(eval_single)(U)
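The inverse-CDF step in `eval_single` is the standard probability integral transform: if u ~ U(0, 1) and F is the CDF of the target distribution, then F⁻¹(u) follows that distribution. A standalone sketch of the same idea for a normally distributed parameter, using jax.scipy.special.ndtri as the inverse CDF (the pmrf `distribution.icdf` interface is assumed to behave analogously):

# Illustrative only: the probability integral transform used by eval_single.
from jax.scipy.special import ndtri  # inverse CDF of the standard normal

def normal_icdf(u, mu=0.0, sigma=1.0):
    # Map a uniform draw u in (0, 1) to a sample from N(mu, sigma^2).
    return mu + sigma * ndtri(u)

u = jr.uniform(jr.PRNGKey(0), (5,))
theta = jax.vmap(normal_icdf)(u)  # five physically distributed samples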