"""Exploration samplers for ``pmrf.explore.samplers``."""

import abc
from typing import Any, Callable

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx

from pmrf.utils.random import lhs_sample


class SamplerState(eqx.Module):
    """
    Standardized state object passed between iteration steps.

    Carries the accumulated samples so that a sampler's ``step`` can extend
    them and ``terminate`` can inspect the running count.
    """
    # Parameter vectors evaluated so far (stacked row-wise).
    sampled_params: jnp.ndarray
    # Evaluated features, one row per row of ``sampled_params``.
    sampled_features: jnp.ndarray
    # Number of samples accumulated so far; compared against the target budget.
    num_samples: int
    # Optional field for backends to store surrogate models or convergence metrics
    backend_state: Any = None 


class AbstractSampler(eqx.Module):
    """
    Abstract base class for all exploration engines, mirroring the
    Optimistix AbstractIterativeSolver paradigm.
    """

    @abc.abstractmethod
    def init(
        self,
        eval_fn: Callable[[jnp.ndarray], jnp.ndarray],
        d: int,
        key: jax.Array,
        options: dict[str, Any],
    ) -> SamplerState:
        """Initialize the sampler state and perform the first evaluations."""
        ...

    @abc.abstractmethod
    def step(
        self,
        eval_fn: Callable[[jnp.ndarray], jnp.ndarray],
        d: int,
        state: SamplerState,
        key: jax.Array,
        options: dict[str, Any],
    ) -> SamplerState:
        """Perform one iteration of the sampling algorithm."""
        ...

    @abc.abstractmethod
    def terminate(self, state: SamplerState, target_N: int) -> bool:
        """Determine whether the sampling loop should stop."""
        ...
# --------------------------------------------------------------------------
# One-Shot Samplers
# --------------------------------------------------------------------------
class AbstractOneShotSampler(AbstractSampler):
    """
    One-shot samplers generate all points immediately during `init`.
    `step` is effectively a no-op, and `terminate` always returns True.
    """

    @abc.abstractmethod
    def generate(self, N: int, d: int, key: jax.Array) -> jnp.ndarray:
        """Propose N points within the unit hypercube."""
        ...

    def init(self, eval_fn, d, key, options) -> SamplerState:
        """Draw every point up front, evaluate them, and return the final state."""
        # For one-shot sampling, the total budget is read from the options dict.
        target = options.get("target_N", 100)
        proposals = self.generate(target, d, key)
        thetas, features = eval_fn(proposals)
        return SamplerState(
            sampled_params=thetas,
            sampled_features=features,
            num_samples=target,
        )

    def step(self, eval_fn, d, state, key, options) -> SamplerState:
        """No-op: every point was already produced in `init`."""
        return state

    def terminate(self, state, target_N) -> bool:
        """Always stop: one-shot sampling is complete after `init`."""
        return True
class LatinHypercube(AbstractOneShotSampler):
    """Sampler using Latin Hypercube Sampling (LHS)."""

    def generate(self, N: int, d: int, key: jax.Array) -> jnp.ndarray:
        """Delegate to the shared LHS helper over the unit hypercube."""
        return lhs_sample(N, d, key)
class Uniform(AbstractOneShotSampler):
    """Sampler using uniform random sampling."""

    def generate(self, N: int, d: int, key: jax.Array) -> jnp.ndarray:
        """Draw N i.i.d. uniform points in the unit hypercube [0, 1)^d."""
        # Use the module-wide `jr` alias (see `jr.split` elsewhere in this file)
        # instead of spelling out `jax.random`, for consistency.
        return jr.uniform(key, shape=(N, d))
# --------------------------------------------------------------------------
# Adaptive (Active Learning) Samplers
# --------------------------------------------------------------------------
class AbstractAdaptiveSampler(AbstractSampler):
    """
    Adaptive samplers iterate until a budget is reached or convergence is met.
    """

    # Size of the initial Latin Hypercube design evaluated in `init`.
    initial_models: int = 10
    # Number of new points proposed per `step`.
    batch_size: int = 1

    def init(self, eval_fn, d, key, options) -> SamplerState:
        """Seed the state with an LHS design of `initial_models` points."""
        seed_points = lhs_sample(self.initial_models, d, key)
        thetas, features = eval_fn(seed_points)
        return SamplerState(
            sampled_params=thetas,
            sampled_features=features,
            num_samples=self.initial_models,
        )

    def terminate(self, state: SamplerState, target_N: int) -> bool:
        """Stop once the accumulated sample count reaches the budget."""
        return state.num_samples >= target_N
class FieldSampler(AbstractAdaptiveSampler):
    """Samples new points at the maxima of a learned scalar field."""

    # Number of candidate grid points per input dimension.
    num_grid_per_dim: int = 1024
    # One-shot sampler used to lay down the candidate grid.
    grid_sampler: AbstractOneShotSampler = LatinHypercube()

    @abc.abstractmethod
    def train_field(self, params: jnp.ndarray, features: jnp.ndarray, key: jax.Array) -> Any:
        """Fit the scalar field on the samples gathered so far."""
        pass

    @abc.abstractmethod
    def evaluate_field(self, field: Any, theta: jnp.ndarray, key: jax.Array) -> float:
        """Score a single candidate point under the trained field."""
        pass

    def step(self, eval_fn, d, state, key, options) -> SamplerState:
        """
        One adaptive iteration: train the field, score a candidate grid,
        pick `batch_size` diverse maximizers, evaluate them, extend the state.
        """
        key, field_key, grid_key, eval_key = jr.split(key, 4)

        # 1. Train the field on current state
        field = self.train_field(state.sampled_params, state.sampled_features, field_key)

        # 2. Generate candidate grid on the unit hypercube
        K = self.num_grid_per_dim * d
        U_grid = self.grid_sampler.generate(K, d, grid_key)

        # 3. Evaluate field on grid (one PRNG key per candidate)
        eval_keys = jr.split(eval_key, K)
        grid_field = jax.vmap(lambda u, k: self.evaluate_field(field, u, k))(U_grid, eval_keys)

        # 4. Greedy diversity selection to get new hypercube proposals
        U_next, _ = self._select_field_points(self.batch_size, U_grid, grid_field)

        # 5. Evaluate the physical model (using the black-box function)
        new_thetas, new_features = eval_fn(U_next)

        # 6. Update state
        return SamplerState(
            sampled_params=jnp.vstack((state.sampled_params, new_thetas)),
            sampled_features=jnp.vstack((state.sampled_features, new_features)),
            num_samples=state.num_samples + len(new_thetas),
            backend_state=field,
        )

    def _select_field_points(
        self, N: int, points: jnp.ndarray, values: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """
        Selects N points iteratively using a penalized greedy strategy.

        Repeatedly picks the highest-scoring candidate, then down-weights
        candidates near it via a Gaussian penalty so the batch is diverse.
        """
        if N >= len(values):
            # Fewer candidates than requested: return everything unchanged.
            return points, values
        if N == 1:
            best_idx = jnp.argmax(values)
            return points[best_idx][None, :], values[best_idx].reshape(1)

        # Normalize coordinates so the penalty length scale is comparable per axis.
        p_min, p_max = jnp.min(points, axis=0), jnp.max(points, axis=0)
        p_range = jnp.where((p_max - p_min) == 0, 1.0, p_max - p_min)
        norm_points = (points - p_min) / p_range

        # Normalize scores to [0, 1]; degenerate (constant) case short-circuits.
        v_min, v_max = jnp.min(values), jnp.max(values)
        if v_max == v_min:
            return points[:N], values[:N]
        scores = (values - v_min) / (v_max - v_min)

        # Penalty length scale grows with sqrt(dimension).
        L = 0.1 * jnp.sqrt(points.shape[1])

        selected_indices = []
        for _ in range(N):
            best_idx = jnp.argmax(scores)
            selected_indices.append(best_idx)
            if len(selected_indices) < N:
                dist_sq = jnp.sum((norm_points - norm_points[best_idx]) ** 2, axis=1)
                penalty = 1.0 - jnp.exp(-dist_sq / (2 * L**2))
                # Bug fix: preserve the -inf sentinel on already-selected points.
                # A plain `scores * penalty` turns -inf into NaN when a duplicate
                # candidate sits at zero distance (penalty == 0, -inf * 0 = NaN),
                # and NaN wins argmax, re-selecting the same location.
                scores = jnp.where(jnp.isneginf(scores), -jnp.inf, scores * penalty)
                scores = scores.at[best_idx].set(-jnp.inf)

        idx_array = jnp.array(selected_indices)
        return points[idx_array], values[idx_array]
class EqxLearnUncertainty(FieldSampler):
    """Adaptive sampler targeting regions of high surrogate uncertainty."""

    # Surrogate model template; `train_field` fits it to the current samples.
    surrogate: Any = None
    fit_kwargs: dict = eqx.field(default_factory=dict)

    def train_field(self, params: jnp.ndarray, features: jnp.ndarray, key: jax.Array) -> Any:
        """Fit the surrogate on the current samples and return the fitted model."""
        from eqxlearn import fit

        # eqx-learn expects a 2D target matrix: collapse trailing feature axes.
        targets = features.reshape(features.shape[0], -1)
        trained, _ = fit(self.surrogate, X=params, y=targets, key=key, **self.fit_kwargs)
        return trained

    def evaluate_field(self, field: Any, theta: jnp.ndarray, key: jax.Array) -> float:
        """Return the surrogate's predicted uncertainty at `theta`, in dB."""
        _y_mean, y_var = field(theta, return_var=True)
        # sqrt(pi)/2 scales std to an expected magnitude — presumably the Rayleigh
        # mean factor for complex-valued predictions (variance has real/imag parts);
        # NOTE(review): confirm against the surrogate's output convention.
        rayleigh_factor = jnp.sqrt(jnp.pi) / 2.0
        total_std = jnp.sqrt(y_var.real + y_var.imag)
        expected_mae = jnp.mean(rayleigh_factor * total_std)
        return 20 * jnp.log10(expected_mae)