PolyChord

class pmrf.infer.PolyChord(nlive: int = -1, num_repeats: int | None = None, nprior: int = -1, nfail: int = -1, do_clustering: bool = True, feedback: int = 1, precision_criterion: float = 0.001, logzero: float = -1e+30, boost_posterior: float = 0.0, posteriors: bool = True, equals: bool = True, cluster_posteriors: bool = True, write_resume: bool = True, write_paramnames: bool = False, read_resume: bool = True, write_stats: bool = True, write_live: bool = True, write_dead: bool = True, write_prior: bool = True, maximise: bool = False, compression_factor: float = np.float64(0.36787944117144233), synchronous: bool = True, base_dir: str = 'chains', file_root: str = 'test', cluster_dir: str = 'clusters', seed: int = -1, nlives: Dict[float, int]=<factory>, paramnames: list[tuple[str, str]] | None=None)

Bases: AbstractHypercubeSampler

The PolyChord Nested Sampler wrapped in a JAX interface.

Acts as an adapter layer between JAX PyTrees and PolyChord’s required flat 1D NumPy arrays. It automatically handles flattening and unflattening of complex parameter structures, JIT-compiles the likelihood and prior transforms for performance, and bridges JAX operations with PolyChord’s host-based MPI sampling routines.
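The PyTree-to-flat-array bridging described above can be sketched with `jax.flatten_util.ravel_pytree`. This is a minimal illustration under stated assumptions, not pmrf's actual implementation. The helper `make_flat_loglikelihood`, the example likelihood and the data in `args` are all hypothetical. The only thing taken from this page is that the log-likelihood is called as `(params, args)` and returns a scalar.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree


def make_flat_loglikelihood(loglikelihood_fn, params_template, args):
    """Hypothetical adapter: expose a PyTree log-likelihood as a flat-array function."""
    # ravel_pytree returns the flat vector and a function that rebuilds the PyTree.
    flat_template, unravel = ravel_pytree(params_template)

    @jax.jit
    def flat_fn(theta_flat):
        # Rebuild the structured parameters, then evaluate the user's likelihood.
        return loglikelihood_fn(unravel(theta_flat), args)

    def host_fn(theta):
        # Accept a host-side NumPy array and hand back a plain Python float.
        theta = jnp.asarray(theta, dtype=flat_template.dtype)
        return float(flat_fn(theta))

    return host_fn, flat_template.size


def loglikelihood(params, args):
    # Gaussian log-likelihood (up to a constant) centred on args["mu"] and 0.
    resid = params["a"] - args["mu"]
    return -0.5 * jnp.sum(resid**2) - 0.5 * params["b"] ** 2


template = {"a": jnp.zeros(2), "b": jnp.array(0.0)}
args = {"mu": jnp.array([1.0, -1.0])}
fn, ndims = make_flat_loglikelihood(loglikelihood, template, args)
value = fn(np.array([1.0, -1.0, 2.0]))  # flat order: both entries of "a", then "b"
```

JAX flattens dictionaries in sorted key order, so the two entries of `"a"` precede `"b"` in the flat vector. If the wrapper flattens the same way, that ordering also fixes which flat index each `paramnames` entry describes.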

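Parameters:
  • nlive (int) – The number of live points (default: -1).

  • num_repeats (int | None) – The length of the slice sampling chain used to generate each new live point. If None, defaults to 5 * ndims, where ndims is the flattened parameter dimension.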

  • nprior (int) – The number of prior samples to draw before compression begins (default: -1).

  • nfail (int) – The number of failed attempts to spawn a new live point that are tolerated before nested sampling stops (default: -1).

  • do_clustering (bool) – Whether to use k-means clustering to handle multimodal posteriors (default: True).

  • feedback (int) – The level of output written to stdout. 0=none, 1=standard, 2=detailed (default: 1).

  • precision_criterion (float) – The termination criterion: sampling stops once the evidence estimated to remain in the live points falls below this fraction of the total evidence (default: 1e-3).

  • logzero (float) – The numerical value used to represent log(0) (default: -1e30).

  • boost_posterior (float) – Increases the number of posterior samples produced; the achievable boost is limited by num_repeats (default: 0.0).

  • posteriors (bool) – Whether to produce standard posterior output files (default: True).

  • equals (bool) – Whether to output equally weighted posterior samples (default: True).

  • cluster_posteriors (bool) – Whether to produce posterior output files for individual clusters. Has no effect if do_clustering is False (default: True).

  • write_resume (bool) – Whether to continuously write resume files during the run (default: True).

  • write_paramnames (bool) – Whether to generate a .paramnames file for post-processing tools (default: False).

  • read_resume (bool) – Whether to attempt resuming from a previous partially completed run (default: True).

  • write_stats (bool) – Whether to write run statistics to a .stats file (default: True).

  • write_live (bool) – Whether to dump the current live points to disk (default: True).

  • write_dead (bool) – Whether to record the dead points (the core nested sampling output) to disk (default: True).

  • write_prior (bool) – Whether to write prior samples to disk (default: True).

  • maximise (bool) – Whether to run a maximisation step at the end of the run to locate the maximum-likelihood point and its value (default: False).

  • compression_factor (float) – Controls how often output files are updated and clustering is performed (default: np.exp(-1.0) ≈ 0.368).

  • synchronous (bool) – Whether to parallelise with synchronous rather than asynchronous MPI workers (default: True).

  • base_dir (str) – The base directory path where all output files will be saved (default: “chains”).

  • file_root (str) – The root naming convention for all generated output files (default: “test”).

  • cluster_dir (str) – The directory name for cluster-specific outputs (default: “clusters”).

  • seed (int) – Random seed for the sampler. Uses time if set to -1 (default: -1).

  • nlives (dict) – A dictionary mapping log-likelihood contours to the number of live points, allowing the number of live points to vary during the run.

  • paramnames (list of tuple) – A list of (name, LaTeX label) pairs, e.g., [("p1", r"\theta_1")]. Its length must match the flattened dimension of the parameters (u0 in run()).
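Two of the documented constraints can be sketched as plain checks: the `num_repeats` fallback of `5 * ndims`, and the requirement that `paramnames` match the flattened parameter dimension. The helper names `resolve_num_repeats` and `check_paramnames` are hypothetical and not part of pmrf. The sketch only mirrors the behaviour described in the parameter list.

```python
def resolve_num_repeats(num_repeats, ndims):
    """Hypothetical helper: fall back to 5 * ndims when num_repeats is None."""
    return 5 * ndims if num_repeats is None else num_repeats


def check_paramnames(paramnames, ndims):
    """Hypothetical helper: validate (name, latex) pairs against the flattened dimension."""
    if paramnames is None:
        return  # nothing to validate
    if len(paramnames) != ndims:
        raise ValueError(
            f"paramnames has {len(paramnames)} entries but the parameters "
            f"flatten to {ndims} dimensions"
        )
    for entry in paramnames:
        if len(entry) != 2:
            raise ValueError(f"expected a (name, latex) pair, got {entry!r}")


# Raw strings keep the backslash in each LaTeX label intact.
paramnames = [("p1", r"\theta_1"), ("p2", r"\theta_2"), ("p3", r"\sigma")]
check_paramnames(paramnames, ndims=3)
repeats = resolve_num_repeats(None, ndims=3)
```

The raw-string prefix matters for the LaTeX labels. In a plain string, `"\theta_1"` begins with the tab escape `\t` rather than a backslash followed by `theta`.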

run(loglikelihood_fn: Callable[[PyTree, Any], Shaped[jaxlib._jax.Array, '']], prior_transform_fn: Callable[[PyTree, Any], PyTree], u0: PyTree, args: PyTree[Any], key: Array, init_cube_samples: PyTree | None = None, max_steps: int | None = None, **kwargs) → tuple[SampleResult, PyTree]

Execute the sampling algorithm.

Parameters:
  • loglikelihood_fn (callable) – A function taking the physical parameters and args as input and returning the log-likelihood.

  • prior_transform_fn (callable) – A function taking the hypercube parameters and args as input and returning the physical parameters.

  • u0 (PyTree) – The initial parameters in the unit hypercube, either for shape reference or as a starting point.

  • args (PyTree) – Extra arguments passed to both loglikelihood_fn and prior_transform_fn.

  • key (Array) – A JAX PRNG key.

  • init_cube_samples (PyTree, optional) – An optional batched PyTree with the same structure as u0, containing initial hypercube samples used to warm-start the algorithm.

  • max_steps (int, optional) – The maximum number of sampling steps to take. If None, no limit is imposed.

  • **kwargs – Runtime arguments forwarded to the solver backend.
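A pair of callables matching the documented `(params, args)` calling convention might look like the following sketch. The parameter names, the uniform and Gaussian priors, and the straight-line data in `args` are illustrative assumptions. Only the calling convention, the hypercube-to-physical role of the prior transform, and the shape-reference role of `u0` come from the `run` description above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import ndtri


def prior_transform_fn(u, args):
    # Map unit-hypercube samples to physical parameters (illustrative priors):
    # "slope" uniform on [-5, 5]; "offset" Gaussian with mean 0 and std 2.
    return {
        "slope": -5.0 + 10.0 * u["slope"],
        "offset": 2.0 * ndtri(u["offset"]),
    }


def loglikelihood_fn(params, args):
    # Gaussian log-likelihood (up to a constant) for a line y = slope * x + offset.
    x, y, sigma = args["x"], args["y"], args["sigma"]
    model = params["slope"] * x + params["offset"]
    return -0.5 * jnp.sum(((y - model) / sigma) ** 2)


# u0 fixes the PyTree structure; 0.5 maps to the centre of each prior.
u0 = {"slope": jnp.array(0.5), "offset": jnp.array(0.5)}
args = {
    "x": jnp.array([0.0, 1.0, 2.0]),
    "y": jnp.array([0.0, 0.0, 0.0]),
    "sigma": jnp.array(1.0),
}
theta = prior_transform_fn(u0, args)
logl = jax.jit(loglikelihood_fn)(theta, args)
```

With a `pmrf.infer.PolyChord` instance named `sampler`, these would be passed as `sampler.run(loglikelihood_fn, prior_transform_fn, u0, args, key)`, for example with `key = jax.random.PRNGKey(0)`. The call returns the `(SampleResult, metrics)` tuple described below.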

Returns:

A tuple of (pmrf.infer.SampleResult, metrics).

Return type:

tuple

base_dir: str = 'chains'
boost_posterior: float = 0.0
cluster_dir: str = 'clusters'
cluster_posteriors: bool = True
compression_factor: float = np.float64(0.36787944117144233)
do_clustering: bool = True
equals: bool = True
feedback: int = 1
file_root: str = 'test'
logzero: float = -1e+30
maximise: bool = False
nfail: int = -1
nlive: int = -1
nlives: Dict[float, int]
nprior: int = -1
num_repeats: int | None = None
paramnames: list[tuple[str, str]] | None = None
posteriors: bool = True
precision_criterion: float = 0.001
read_resume: bool = True
property requires_hypercube
seed: int = -1
synchronous: bool = True
write_dead: bool = True
write_live: bool = True
write_paramnames: bool = False
write_prior: bool = True
write_resume: bool = True
write_stats: bool = True