NSS
- class pmrf.infer.NSS(num_delete: int | None = None, num_inner_steps: int | None = None, evidence_convergence: float = 0.001, block_size: int | None = 100)
Bases: AbstractSplitSampler
(experimental) A Nested Slice Sampler (NSS) in JAX.
A wrapper around BlackJAX’s experimental nested slice sampler. It requires a custom fork of BlackJAX, which can be installed with pip install git+https://github.com/handley-lab/blackjax.git@v0.1.0-beta.
- Parameters:
num_delete (int, optional) – Number of live points to delete per step and therefore vectorize over. Defaults to 0.1 x num_live if not provided.
num_inner_steps (int, optional) – The length of the short Markov chains used to update the live points. Defaults to 3 x dim if not provided.
evidence_convergence (float, default=1e-3) – Threshold for evidence convergence when max_steps is None.
block_size (int, optional) – The number of steps to execute on-device per block before checking convergence. Defaults to 100.
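The defaults and the block-wise convergence check described above can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helpers resolve_defaults and run_in_blocks, the toy step and converged functions, and the rounding rule for 0.1 x num_live are all assumptions made for illustration. The sketch also assumes, following the description of evidence_convergence, that convergence is only tested when max_steps is None.

```python
def resolve_defaults(num_live, dim, num_delete=None, num_inner_steps=None):
    """Hypothetical helper: fill in the documented defaults."""
    if num_delete is None:
        # Documented default is 0.1 x num_live; rounding down with a floor
        # of 1 is an assumption of this sketch.
        num_delete = max(1, num_live // 10)
    if num_inner_steps is None:
        # Documented default is 3 x dim.
        num_inner_steps = 3 * dim
    return num_delete, num_inner_steps


def run_in_blocks(step, converged, state, block_size=100, max_steps=None):
    """Hypothetical driver: run `block_size` steps, then check convergence."""
    steps_taken = 0
    while True:
        n = block_size
        if max_steps is not None:
            # Never run past the step budget.
            n = min(n, max_steps - steps_taken)
        if n <= 0:
            break
        for _ in range(n):
            state = step(state)
        steps_taken += n
        # Assumption: convergence is only consulted when no step limit is set.
        if max_steps is None and converged(state):
            break
    return state, steps_taken


# Toy stand-ins: the "remaining evidence fraction" halves on every step.
halve = lambda remaining: remaining * 0.5
below_threshold = lambda remaining: remaining < 1e-3

defaults = resolve_defaults(num_live=500, dim=4)
_, steps_unbounded = run_in_blocks(halve, below_threshold, 1.0, block_size=4)
_, steps_bounded = run_in_blocks(halve, below_threshold, 1.0, block_size=4, max_steps=6)
```

In this toy run the threshold is first crossed after 10 steps, but the loop only notices at the next block boundary, so a larger block_size trades fewer convergence checks for some extra steps.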
- run(loglikelihood_fn: Callable[[PyTree, Any], Any], logprior_fn: Callable[[PyTree], Shaped[jaxlib._jax.Array, '']], y0: PyTree, args: PyTree[Any], key: Array, init_samples: PyTree = None, max_steps: int | None = None, **kwargs) tuple[SampleResult, PyTree]
Execute the sampling algorithm.
- Parameters:
loglikelihood_fn (callable) – A function taking the parameters and args as input and returning the log-likelihood.
logprior_fn (callable) – A function taking the parameters as input and returning the log prior probability as a scalar array.
y0 (PyTree) – The initial parameters, either for shape reference or as a starting point.
args (Any) – Additional arguments passed to loglikelihood_fn.
key (Array) – A random JAX key.
init_samples (PyTree, optional) – An optional batched PyTree the same structure as y0 with initial samples to warm-start the algorithm.
max_steps (int, optional) – The maximum number of sampling steps to take. If None, no limit is imposed.
**kwargs – Runtime arguments forwarded to the solver backend.
- Returns:
A tuple of (pmrf.infer.SampleResult, metrics).
- Return type:
tuple
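As a rough illustration of the inputs run expects, the sketch below builds callables matching the documented signatures for a toy straight-line fit. The model, the parameter names slope and intercept, the standard-normal prior, and the batch size of 200 are all assumptions chosen for the example, not part of pmrf. The final call is left commented out because it needs pmrf and the BlackJAX fork installed.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def logprior_fn(params):
    # Independent standard-normal prior on every leaf, summed to a scalar.
    leaves = jax.tree_util.tree_leaves(params)
    return sum(jnp.sum(norm.logpdf(leaf)) for leaf in leaves)


def loglikelihood_fn(params, args):
    # `args` carries the fixed data: inputs, observations, noise level.
    x, y_obs, sigma = args
    y_model = params["slope"] * x + params["intercept"]
    return jnp.sum(norm.logpdf(y_obs, loc=y_model, scale=sigma))


# y0 fixes the PyTree structure of the parameters.
y0 = {"slope": jnp.array(0.0), "intercept": jnp.array(0.0)}

x = jnp.linspace(0.0, 1.0, 5)
args = (x, 2.0 * x + 1.0, 0.1)

key = jax.random.PRNGKey(0)
k_slope, k_intercept = jax.random.split(key)

# Batched PyTree with the same structure as y0: one leading batch axis.
init_samples = {
    "slope": jax.random.normal(k_slope, (200,)),
    "intercept": jax.random.normal(k_intercept, (200,)),
}

# With pmrf installed, the call would follow the documented signature:
# from pmrf.infer import NSS
# result, metrics = NSS().run(loglikelihood_fn, logprior_fn, y0, args, key,
#                             init_samples=init_samples)
```

Both callables return a scalar array for a single parameter PyTree, and they can be mapped over the leading axis of init_samples with jax.vmap.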
- block_size: int | None = 100
- evidence_convergence: float = 0.001
- num_delete: int | None = None
- num_inner_steps: int | None = None