HMC

class pmrf.infer.HMC(num_warmup: int = 1000, target_acceptance_rate: float = 0.8, num_integration_steps: int = 30)

Bases: AbstractJointSampler

Hamiltonian Monte Carlo (HMC) in JAX.

Wrapper around blackjax.hmc.

The number of integration steps is static and must be fixed in advance. The step size and mass matrix are adapted automatically during warmup.
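To make the fixed-length trajectory concrete, here is a minimal from-scratch sketch of an HMC transition in JAX with a diagonal mass matrix. It is not the blackjax or pmrf implementation. The helpers leapfrog, hmc_step and run_chain are hypothetical, num_steps plays the role of num_integration_steps, and the step size and inverse mass are supplied by hand rather than adapted.

```python
import jax
import jax.numpy as jnp


def leapfrog(logdensity_fn, q, p, step_size, inv_mass, num_steps):
    """Integrate Hamilton's equations with a fixed (static) number of leapfrog steps."""
    grad_fn = jax.grad(logdensity_fn)

    def body(_, state):
        q, p = state
        p = p + 0.5 * step_size * grad_fn(q)  # half step on momentum
        q = q + step_size * inv_mass * p      # full step on position
        p = p + 0.5 * step_size * grad_fn(q)  # half step on momentum
        return q, p

    # num_steps is a Python int here, so the loop length is known at trace time.
    return jax.lax.fori_loop(0, num_steps, body, (q, p))


def hmc_step(key, q, logdensity_fn, step_size, inv_mass, num_steps):
    """One HMC transition with mass matrix M = diag(1 / inv_mass)."""
    key_p, key_u = jax.random.split(key)
    # Resample momentum p ~ N(0, M).
    p = jax.random.normal(key_p, q.shape) / jnp.sqrt(inv_mass)

    def hamiltonian(q_, p_):
        # Potential energy (negative log density) plus kinetic energy p^T M^{-1} p / 2.
        return -logdensity_fn(q_) + 0.5 * jnp.sum(inv_mass * p_**2)

    q_new, p_new = leapfrog(logdensity_fn, q, p, step_size, inv_mass, num_steps)
    # Metropolis correction for the integrator's discretisation error.
    accept_prob = jnp.minimum(1.0, jnp.exp(hamiltonian(q, p) - hamiltonian(q_new, p_new)))
    accepted = jax.random.uniform(key_u) < accept_prob
    return jnp.where(accepted, q_new, q), accept_prob


def run_chain(key, q0, logdensity_fn, step_size, inv_mass, num_steps, num_samples):
    """Run a single chain with lax.scan and collect samples and acceptance probabilities."""
    def step(q, k):
        q, accept_prob = hmc_step(k, q, logdensity_fn, step_size, inv_mass, num_steps)
        return q, (q, accept_prob)

    keys = jax.random.split(key, num_samples)
    _, (samples, accept_probs) = jax.lax.scan(step, q0, keys)
    return samples, accept_probs


# Toy target: a 2-D standard normal (unnormalized log density).
logdensity = lambda q: -0.5 * jnp.sum(q**2)
samples, accept = run_chain(
    jax.random.PRNGKey(0), jnp.zeros(2), logdensity,
    step_size=0.2, inv_mass=jnp.ones(2), num_steps=10, num_samples=2000,
)
```

Because the number of leapfrog steps is fixed, every transition has the same cost and the loop compiles to a fixed-length structure, which is one reason a static value is convenient in JAX.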

Parameters:
  • num_warmup (int, default=1000) – Number of warmup steps for window adaptation.

  • target_acceptance_rate (float, default=0.8) – Target acceptance rate for step size adaptation.

  • num_integration_steps (int, default=30) – Number of integration steps per transition.
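How a target acceptance rate can drive step-size tuning over a fixed number of warmup iterations is illustrated below with dual averaging (Hoffman and Gelman, 2014), the scheme used for step-size tuning in Stan-style warmup. This is a standalone pure-Python sketch under that assumption, not the pmrf or blackjax code: dual_averaging_step_size is a hypothetical helper, gamma, t0 and kappa are the defaults quoted in that paper, and accept_prob_fn stands in for the acceptance probability a real sampler would report at each warmup step.

```python
import math


def dual_averaging_step_size(accept_prob_fn, target=0.8, num_warmup=1000,
                             initial_step_size=1.0, gamma=0.05, t0=10.0, kappa=0.75):
    """Tune a step size so that accept_prob_fn(step_size) approaches `target`."""
    mu = math.log(10.0 * initial_step_size)  # value the iterates are shrunk towards
    h_bar = 0.0                              # running average of (target - acceptance)
    log_step = math.log(initial_step_size)
    log_step_bar = 0.0                       # averaged iterate, returned at the end
    for t in range(1, num_warmup + 1):
        accept_prob = accept_prob_fn(math.exp(log_step))
        w = 1.0 / (t + t0)
        h_bar = (1.0 - w) * h_bar + w * (target - accept_prob)
        # Too-low acceptance makes h_bar positive, which shrinks the step size.
        log_step = mu - math.sqrt(t) / gamma * h_bar
        eta = t ** (-kappa)
        log_step_bar = eta * log_step + (1.0 - eta) * log_step_bar
    return math.exp(log_step_bar)


# Hypothetical acceptance curve: acceptance falls as the step size grows.
toy_accept = lambda step_size: math.exp(-step_size)
tuned = dual_averaging_step_size(toy_accept, target=0.8, num_warmup=1000)
```

With this toy curve the step size that gives an acceptance of 0.8 is -ln(0.8), roughly 0.223, and the averaged iterate should settle near that value. A real sampler reports a noisy acceptance probability, which is why the averaged iterate rather than the last one is returned.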

run(logposterior_fn: Callable[[PyTree, Any], Any], y0: PyTree, args: PyTree[Any], key: Array, init_samples: PyTree | None = None, max_steps: int | None = 1000, **kwargs) → tuple[SampleResult, PyTree]

Execute the sampling algorithm.

Parameters:
  • logposterior_fn (callable) – A function that takes the parameters and args as input and returns the log-posterior density (which may be unnormalized).

  • y0 (PyTree) – The initial parameters, either for shape reference or as a starting point.

  • args (Any) – Additional arguments passed to logposterior_fn.

  • key (Array) – A random JAX key.

  • init_samples (PyTree, optional) – An optional batched PyTree with the same structure as y0, containing initial samples used to warm-start the algorithm.

  • max_steps (int, optional) – The maximum number of sampling steps to take. If None, no limit is imposed.

  • **kwargs – Runtime arguments forwarded to the solver backend.
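The contract for logposterior_fn, y0, args and init_samples can be illustrated with a hypothetical two-parameter model. The model, the dictionary layout of y0 and args, and the choice of a leading batch axis for init_samples are assumptions made for illustration. The code does not call pmrf; it only checks that the log-posterior is differentiable with JAX (HMC needs its gradient) and builds PyTrees of the expected shape.

```python
import jax
import jax.numpy as jnp


def log_posterior(params, args):
    """Hypothetical unnormalized log-posterior: N(0, 1) priors plus a Gaussian likelihood."""
    y_obs = args["y_obs"]
    mu, log_sigma = params["mu"], params["log_sigma"]
    log_prior = -0.5 * (mu**2 + log_sigma**2)
    sigma = jnp.exp(log_sigma)
    log_lik = jnp.sum(-0.5 * ((y_obs - mu) / sigma) ** 2 - log_sigma)
    return log_prior + log_lik


# y0: a PyTree of parameters (here a dict of scalars); args: data passed through unchanged.
y0 = {"mu": jnp.array(0.0), "log_sigma": jnp.array(0.0)}
args = {"y_obs": jnp.array([0.9, 1.1, 1.3, 0.7])}

# Gradient-based samplers differentiate logposterior_fn with respect to the parameters.
value, grads = jax.value_and_grad(log_posterior)(y0, args)

# init_samples: same tree structure as y0, with an extra leading batch axis on every leaf
# (assumed layout), filled here with small random perturbations of y0.
num_init = 4
leaves, treedef = jax.tree_util.tree_flatten(y0)
keys = jax.random.split(jax.random.PRNGKey(0), len(leaves))
init_samples = jax.tree_util.tree_unflatten(
    treedef,
    [leaf + 0.1 * jax.random.normal(k, (num_init,) + leaf.shape) for leaf, k in zip(leaves, keys)],
)
```

Under the documented signature, these objects would then be passed as HMC().run(log_posterior, y0, args, key, init_samples=init_samples).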

Returns:

A tuple (pmrf.infer.SampleResult, metrics).

Return type:

tuple

num_integration_steps: int = 30
num_warmup: int = 1000
target_acceptance_rate: float = 0.8