PolyChord
- class pmrf.infer.PolyChord(nlive: int = -1, num_repeats: int | None = None, nprior: int = -1, nfail: int = -1, do_clustering: bool = True, feedback: int = 1, precision_criterion: float = 0.001, logzero: float = -1e+30, boost_posterior: float = 0.0, posteriors: bool = True, equals: bool = True, cluster_posteriors: bool = True, write_resume: bool = True, write_paramnames: bool = False, read_resume: bool = True, write_stats: bool = True, write_live: bool = True, write_dead: bool = True, write_prior: bool = True, maximise: bool = False, compression_factor: float = np.float64(0.36787944117144233), synchronous: bool = True, base_dir: str = 'chains', file_root: str = 'test', cluster_dir: str = 'clusters', seed: int = -1, nlives: Dict[float, int]=<factory>, paramnames: list[tuple[str, str]] | None=None)
Bases: AbstractHypercubeSampler
The PolyChord nested sampler wrapped in a JAX interface.
Acts as an adapter layer between JAX PyTrees and PolyChord’s required flat 1D NumPy arrays. It automatically handles flattening and unflattening of complex parameter structures, JIT-compiles the likelihood and prior transforms for performance, and bridges JAX operations with PolyChord’s host-based MPI sampling routines.
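The flatten/JIT/host-bridge idea described above can be sketched with JAX's own `jax.flatten_util.ravel_pytree`. This is a minimal illustration under stated assumptions, not pmrf's actual implementation: the PyTree `u0`, the uniform prior, the toy Gaussian likelihood and the wrapper names `host_prior` / `host_loglike` are all hypothetical. It shows how a PyTree-valued prior transform and likelihood can be exposed as callables that take and return flat NumPy arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree

# Hypothetical hypercube PyTree; any nesting of arrays is handled the same way.
u0 = {"gain": jnp.array(0.5), "poles": jnp.array([0.1, 0.9])}
args = (-5.0, 5.0)  # hypothetical (low, high) bounds for a uniform prior


def prior_transform_fn(u, args):
    # Map every hypercube coordinate linearly onto [low, high].
    low, high = args
    return jax.tree_util.tree_map(lambda x: low + (high - low) * x, u)


def loglikelihood_fn(theta, args):
    # Toy isotropic Gaussian log-likelihood (up to a constant).
    return -0.5 * sum(jnp.sum(x**2) for x in jax.tree_util.tree_leaves(theta))


# Flatten once to learn the flat size and obtain the inverse ("unravel") maps.
flat_u0, unravel_u = ravel_pytree(u0)
_, unravel_theta = ravel_pytree(prior_transform_fn(u0, args))


@jax.jit
def _flat_prior(cube):
    # flat cube -> PyTree -> physical PyTree -> flat physical vector
    theta = prior_transform_fn(unravel_u(cube), args)
    return ravel_pytree(theta)[0]


@jax.jit
def _flat_loglike(theta_flat):
    # flat physical vector -> PyTree -> scalar log-likelihood
    return loglikelihood_fn(unravel_theta(theta_flat), args)


# Host-side wrappers: NumPy arrays in, NumPy array / Python float out.
def host_prior(cube):
    return np.asarray(_flat_prior(jnp.asarray(cube)), dtype=np.float64)


def host_loglike(theta):
    return float(_flat_loglike(jnp.asarray(theta)))
```

Flattening happens once up front, so the jitted functions only see fixed-size 1D vectors, which is what a host-based sampler working on flat arrays needs.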
- Parameters:
nlive (int) – The number of live points (default: -1).
num_repeats (int | None) – The length of the slice sampling chain used to generate a new live point. If None, defaults to 5 * ndims, where ndims is the flattened parameter dimension.
nprior (int) – The number of prior samples to draw before compression begins (default: -1).
nfail (int) – The number of failed slice sampling steps before giving up (default: -1).
do_clustering (bool) – Whether to use k-means clustering to handle multimodal posteriors (default: True).
feedback (int) – The level of output written to stdout. 0=none, 1=standard, 2=detailed (default: 1).
precision_criterion (float) – The stopping criterion based on the estimated evidence precision (default: 1e-3).
logzero (float) – The numerical value used to represent log(0) (default: -1e30).
boost_posterior (float) – Factor by which to increase the number of posterior samples by also using intermediate slice-sampling chain points; values <= 0 use only dead points (default: 0.0).
posteriors (bool) – Whether to produce standard posterior output files (default: True).
equals (bool) – Whether to output equally weighted posterior samples (default: True).
cluster_posteriors (bool) – Whether to produce posterior output files for individual clusters (default: True).
write_resume (bool) – Whether to continuously write resume files during the run (default: True).
write_paramnames (bool) – Whether to generate a .paramnames file for post-processing tools (default: False).
read_resume (bool) – Whether to attempt resuming from a previous partially completed run (default: True).
write_stats (bool) – Whether to write run statistics to a .stats file (default: True).
write_live (bool) – Whether to dump the current live points to disk (default: True).
write_dead (bool) – Whether to record the dead points (the core nested sampling output) to disk (default: True).
write_prior (bool) – Whether to write prior samples to disk (default: True).
maximise (bool) – Whether to run a maximisation step at the end of the run to find the maximum-likelihood point (default: False).
compression_factor (float) – Controls how often output files are updated and clustering is performed, expressed as a prior-volume compression factor (default: np.exp(-1.0)).
synchronous (bool) – Whether to run MPI operations synchronously (default: True).
base_dir (str) – The base directory path where all output files will be saved (default: “chains”).
file_root (str) – The root naming convention for all generated output files (default: “test”).
cluster_dir (str) – The directory name for cluster-specific outputs (default: “clusters”).
seed (int) – Random seed for the sampler. Uses time if set to -1 (default: -1).
nlives (dict) – A dictionary mapping log-likelihood contours to the number of live points.
paramnames (list of tuple) – A list of (name, LaTeX name) pairs, e.g., [("p1", r"\theta_1")]. Its length must match the flattened dimension of u0.
- run(loglikelihood_fn: Callable[[PyTree, Any], Shaped[jaxlib._jax.Array, '']], prior_transform_fn: Callable[[PyTree, Any], PyTree], u0: PyTree, args: PyTree[Any], key: Array, init_cube_samples: PyTree | None = None, max_steps: int | None = None, **kwargs) → tuple[SampleResult, PyTree]
Execute the sampling algorithm.
- Parameters:
loglikelihood_fn (callable) – A function taking the physical parameters and args as input and returning the log-likelihood.
prior_transform_fn (callable) – A function taking the hypercube parameters and args as input and returning the physical parameters.
u0 (PyTree) – The initial parameters in the unit hypercube, either for shape reference or as a starting point.
args (PyTree) – Additional arguments passed through to loglikelihood_fn and prior_transform_fn.
key (Array) – A JAX PRNG key.
init_cube_samples (PyTree, optional) – An optional batched PyTree with the same structure as u0, holding initial hypercube samples used to warm-start the algorithm.
max_steps (int, optional) – The maximum number of sampling steps to take. If None, no limit is imposed.
**kwargs – Runtime arguments forwarded to the solver backend.
- Returns:
A tuple of (pmrf.infer.SampleResult, metrics).
- Return type:
tuple
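The inputs expected by run can be prepared and sanity-checked in plain JAX before any sampling starts. In the sketch below, the parameter structure, the args dictionary and the uniform prior are hypothetical; the function signatures follow the documented ones (PyTree and args in, a scalar log-likelihood or a physical PyTree out). It also builds an optional init_cube_samples PyTree with the same structure as u0 plus a leading batch axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter structure in the unit hypercube.
u0 = {"a": jnp.array(0.5), "b": jnp.full((2,), 0.5)}
args = {"centre": 1.0, "width": 4.0}  # hypothetical args PyTree


def prior_transform_fn(u, args):
    # Hypercube -> physical: uniform on [centre - width/2, centre + width/2).
    return jax.tree_util.tree_map(
        lambda x: args["centre"] + args["width"] * (x - 0.5), u
    )


def loglikelihood_fn(theta, args):
    # Must return a scalar (shape ()) log-likelihood.
    leaves = jax.tree_util.tree_leaves(theta)
    return -0.5 * sum(jnp.sum((x - args["centre"]) ** 2) for x in leaves)


key = jax.random.PRNGKey(0)
init_key, run_key = jax.random.split(key)

# Optional warm start: same structure as u0 with a leading batch axis.
n_init = 64
leaves, treedef = jax.tree_util.tree_flatten(u0)
leaf_keys = jax.random.split(init_key, len(leaves))
init_cube_samples = jax.tree_util.tree_unflatten(
    treedef,
    [jax.random.uniform(k, (n_init,) + jnp.shape(x)) for k, x in zip(leaf_keys, leaves)],
)

# Sanity check: both functions vectorise over the batch axis.
theta_batch = jax.vmap(prior_transform_fn, in_axes=(0, None))(init_cube_samples, args)
logl_batch = jax.vmap(loglikelihood_fn, in_axes=(0, None))(theta_batch, args)
```

With a sampler instance, these would then be passed as sampler.run(loglikelihood_fn, prior_transform_fn, u0, args, run_key, init_cube_samples=init_cube_samples), which per the documented signature returns (SampleResult, metrics).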
- base_dir: str = 'chains'
- boost_posterior: float = 0.0
- cluster_dir: str = 'clusters'
- cluster_posteriors: bool = True
- compression_factor: float = np.float64(0.36787944117144233)
- do_clustering: bool = True
- equals: bool = True
- feedback: int = 1
- file_root: str = 'test'
- logzero: float = -1e+30
- maximise: bool = False
- nfail: int = -1
- nlive: int = -1
- nlives: Dict[float, int]
- nprior: int = -1
- num_repeats: int | None = None
- paramnames: list[tuple[str, str]] | None = None
- posteriors: bool = True
- precision_criterion: float = 0.001
- read_resume: bool = True
- property requires_hypercube
- seed: int = -1
- synchronous: bool = True
- write_dead: bool = True
- write_live: bool = True
- write_paramnames: bool = False
- write_prior: bool = True
- write_resume: bool = True
- write_stats: bool = True