NSS

class pmrf.infer.NSS(num_delete: int | None = None, num_inner_steps: int | None = None, evidence_convergence: float = 0.001, block_size: int | None = 100)

Bases: AbstractSplitSampler

(experimental) A Nested Slice Sampler (NSS) in JAX.

A wrapper around BlackJAX’s experimental NSS sampler. It requires a custom fork of BlackJAX, which can be installed with:

pip install git+https://github.com/handley-lab/blackjax.git@v0.1.0-beta

Parameters:
  • num_delete (int, optional) – Number of live points deleted and replaced per step; the replacement updates are vectorized over these points. Defaults to 0.1 × num_live if not provided.

  • num_inner_steps (int, optional) – Length of the short Markov chains used to update the live points. Defaults to 3 × dim, where dim is the dimensionality of the parameter space, if not provided.

  • evidence_convergence (float, default=1e-3) – Threshold for evidence convergence when max_steps is None.

  • block_size (int, optional) – The number of steps to execute on-device per block before checking convergence. Defaults to 100.
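The two data-dependent defaults above (num_delete and num_inner_steps) can be sketched as follows. resolve_nss_defaults is a hypothetical helper, not part of pmrf. It assumes that dim counts every scalar entry of the parameter PyTree y0 and that 0.1 × num_live is rounded down with a floor of one point. It also takes num_live as an input, because this page does not say how the number of live points is chosen.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def resolve_nss_defaults(y0, num_live, num_delete=None, num_inner_steps=None):
    """Hypothetical helper: fill in the NSS defaults described above.

    Assumptions (not pmrf's actual code): `dim` is the total number of scalar
    entries in the parameter PyTree `y0`, and 0.1 * num_live is rounded down
    and clamped to at least one point.
    """
    flat, _ = ravel_pytree(y0)  # flatten all leaves into a single 1-D array
    dim = flat.size
    if num_delete is None:
        num_delete = max(1, int(0.1 * num_live))
    if num_inner_steps is None:
        num_inner_steps = 3 * dim
    return num_delete, num_inner_steps


# A PyTree with 2 + 3 = 5 scalar parameters.
y0 = {"gain": jnp.zeros(2), "loss": jnp.zeros(3)}
print(resolve_nss_defaults(y0, num_live=500))  # (50, 15)
```

Explicitly passed values always take precedence over the computed defaults, so resolve_nss_defaults(y0, 500, num_delete=7) keeps num_delete at 7.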

run(loglikelihood_fn: Callable[[PyTree, Any], Any], logprior_fn: Callable[[PyTree], Shaped[jaxlib._jax.Array, '']], y0: PyTree, args: PyTree[Any], key: Array, init_samples: PyTree = None, max_steps: int | None = None, **kwargs) → tuple[SampleResult, PyTree]

Execute the sampling algorithm.

Parameters:
  • loglikelihood_fn (callable) – A function taking the parameters and args as input and returning the log-likelihood.

  • logprior_fn (callable) – A function taking the parameters as input and returning the log prior probability.

  • y0 (PyTree) – The initial parameters, either for shape reference or as a starting point.

  • args (PyTree) – Additional arguments passed to loglikelihood_fn.

  • key (Array) – A random JAX key.

  • init_samples (PyTree, optional) – A batched PyTree with the same structure as y0, containing initial samples used to warm-start the algorithm.

  • max_steps (int, optional) – The maximum number of sampling steps to take. If None, there is no step limit and sampling stops once the evidence_convergence threshold is reached.

  • **kwargs – Runtime arguments forwarded to the solver backend.
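The way block_size, max_steps and evidence_convergence interact can be pictured as an outer Python loop around a jitted block of on-device steps. The sketch below is a toy, not the NSS implementation. toy_step replaces the slice-sampling update with deterministic nested-sampling shrinkage on the one-dimensional likelihood L(X) = exp(-A X) over prior volume X. The stopping rule (the live-point evidence bound falls below evidence_convergence times the accumulated evidence) is an assumed criterion. run_block, blocked_run, A and NUM_LIVE are hypothetical names.

```python
from functools import partial

import jax
import jax.numpy as jnp

A = 10.0        # toy likelihood L(X) = exp(-A * X) over prior volume X in [0, 1]
NUM_LIVE = 500  # hypothetical number of live points


def toy_step(state, _):
    """One deterministic nested-sampling step (toy stand-in for an NSS step)."""
    log_x, z = state
    log_x_new = log_x - 1.0 / NUM_LIVE  # expected log-shrinkage of the prior volume
    x, x_new = jnp.exp(log_x), jnp.exp(log_x_new)
    z = z + jnp.exp(-A * x_new) * (x - x_new)  # add the dead point's weight
    return (log_x_new, z), None


@partial(jax.jit, static_argnums=1)
def run_block(state, num_steps):
    # Execute `num_steps` steps on-device without returning to Python.
    state, _ = jax.lax.scan(toy_step, state, None, length=num_steps)
    return state


def blocked_run(block_size=100, evidence_convergence=1e-3, max_steps=None):
    """Hypothetical outer loop: run blocks, check convergence between blocks."""
    state = (jnp.array(0.0), jnp.array(0.0))  # log X = 0 (whole prior), Z = 0
    steps = 0
    while max_steps is None or steps < max_steps:
        n = block_size if max_steps is None else min(block_size, max_steps - steps)
        state = run_block(state, n)
        steps += n
        log_x, z = state
        # Assumed stopping rule: the live points can add at most Lmax * X,
        # with Lmax = L(0) = 1 here; stop once that is a small fraction of Z.
        if max_steps is None and jnp.exp(log_x) < evidence_convergence * z:
            break
    return steps, float(state[1])


steps, z = blocked_run()
print(steps, z)  # stops after a whole number of 100-step blocks
```

In this sketch the convergence check runs only when max_steps is None, matching the description of evidence_convergence, and the last block is shortened so that exactly max_steps steps are taken: blocked_run(max_steps=250) runs blocks of 100, 100 and 50 steps. Larger blocks mean fewer host round-trips but may overshoot convergence by up to one block.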

Returns:

A tuple of (pmrf.infer.SampleResult, metrics).

Return type:

tuple
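Following the signatures above, run() expects loglikelihood_fn(params, args) and logprior_fn(params), each returning a scalar. The sketch below shows callables that satisfy those signatures for a hypothetical straight-line model, together with a batched init_samples PyTree. The model, the parameter names a and b, and the uniform prior bounds are illustrative assumptions. The actual call to run() is left as a comment because it needs pmrf and the BlackJAX fork.

```python
import jax
import jax.numpy as jnp


def loglikelihood_fn(params, args):
    """Gaussian log-likelihood of data args["y"] under a hypothetical line model."""
    pred = params["a"] * args["x"] + params["b"]
    resid = (args["y"] - pred) / args["sigma"]
    return -0.5 * jnp.sum(resid**2)


def logprior_fn(params):
    """Uniform prior on [-5, 5] for every parameter; -inf outside the box."""
    leaves = jax.tree_util.tree_leaves(params)
    inside = jnp.all(jnp.stack([jnp.all(jnp.abs(leaf) <= 5.0) for leaf in leaves]))
    log_volume = sum(leaf.size for leaf in leaves) * jnp.log(10.0)
    return jnp.where(inside, -log_volume, -jnp.inf)


args = {"x": jnp.linspace(0.0, 1.0, 5), "y": jnp.linspace(1.0, 3.0, 5), "sigma": 0.1}
y0 = {"a": jnp.array(0.0), "b": jnp.array(0.0)}  # shape reference for the parameters

# init_samples: same structure as y0, with a leading batch axis (here 4 samples).
key = jax.random.PRNGKey(0)
ka, kb = jax.random.split(key)
init_samples = {
    "a": jax.random.uniform(ka, (4,), minval=-5.0, maxval=5.0),
    "b": jax.random.uniform(kb, (4,), minval=-5.0, maxval=5.0),
}

# Both callables return scalars, so they can be vmapped over the batch axis.
ll = jax.vmap(loglikelihood_fn, in_axes=(0, None))(init_samples, args)
lp = jax.vmap(logprior_fn)(init_samples)
print(ll.shape, lp.shape)

# Passing them to the sampler (requires pmrf and the BlackJAX fork; not run here):
# result, metrics = pmrf.infer.NSS().run(
#     loglikelihood_fn, logprior_fn, y0, args, key, init_samples=init_samples)
```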

block_size: int | None = 100
evidence_convergence: float = 0.001
num_delete: int | None = None
num_inner_steps: int | None = None