Gamma

class pmrf.distributions.Gamma(concentration: float | Float[jaxlib._jax.Array, '...'], rate: float | Float[jaxlib._jax.Array, '...'])

Bases: AbstractSampleLogProbDistribution, AbstractProbDistribution, AbstractSurvivalDistribution

Gamma distribution with parameters concentration and rate.

The PDF of a Gamma-distributed random variable \(X\) is supported on \(x > 0\) and has the form:

$$p(x; \alpha, \beta) = \frac{\beta^{\alpha}}{\Gamma(\alpha)} \, x^{\alpha - 1} e^{-\beta x}$$

where \(\alpha > 0\) is the concentration (shape) parameter and \(\beta > 0\) is the rate (inverse scale) parameter.
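As a check on the formula, the log-density can be written out term by term in JAX. This is a minimal sketch, not pmrf's implementation: `gamma_log_prob` is a hypothetical helper, and it is compared against `jax.scipy.stats.gamma.logpdf`, which is parameterized by a scale \(= 1/\beta\) rather than a rate.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln
from jax.scipy.stats import gamma as gamma_stats


def gamma_log_prob(x, concentration, rate):
    """Hypothetical helper: log p(x; alpha, beta) for x > 0, taken directly from the PDF."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return (
        concentration * jnp.log(rate)          # alpha * log(beta)
        - gammaln(concentration)               # - log Gamma(alpha)
        + (concentration - 1.0) * jnp.log(x)   # (alpha - 1) * log(x)
        - rate * x                             # - beta * x
    )


x = jnp.array([0.25, 1.0, 3.0])
ours = gamma_log_prob(x, 3.0, 2.0)
# jax.scipy.stats.gamma takes a scale parameter, so pass scale = 1 / rate.
reference = gamma_stats.logpdf(x, 3.0, scale=1.0 / 2.0)
```

With \(\alpha = 1\) the density reduces to an exponential, \(\log p(x) = \log \beta - \beta x\), which is an easy value to check by hand.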

Initializes a Gamma distribution.

Arguments:

  • concentration: Concentration (shape) parameter. Must be positive.

  • rate: Rate (inverse scale) parameter. Must be positive.

cdf(value: Array) → Array

See Distribution.cdf.
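For this parameterization the CDF is the regularized lower incomplete gamma function, \(F(x) = P(\alpha, \beta x)\), and the survival function is its complement \(Q(\alpha, \beta x)\). The sketch below uses `jax.scipy.special.gammainc` and `gammaincc`; the helper names and the clipping of \(x \le 0\) to a CDF of 0 are assumptions, not documented pmrf behavior.

```python
import jax.numpy as jnp
from jax.scipy.special import gammainc, gammaincc


def gamma_cdf(x, concentration, rate):
    """Hypothetical helper: P(X <= x) = P(alpha, beta * x)."""
    # Assumption: values outside the support (x <= 0) are clipped so the CDF is 0 there.
    x = jnp.maximum(jnp.asarray(x, dtype=jnp.float32), 0.0)
    return gammainc(concentration, rate * x)


def gamma_log_cdf(x, concentration, rate):
    """Hypothetical helper: naive log-CDF; a library may use a more stable tail formulation."""
    return jnp.log(gamma_cdf(x, concentration, rate))


def gamma_survival(x, concentration, rate):
    """Hypothetical helper: P(X > x) = Q(alpha, beta * x), the regularized upper incomplete gamma."""
    x = jnp.maximum(jnp.asarray(x, dtype=jnp.float32), 0.0)
    return gammaincc(concentration, rate * x)
```

For \(\alpha = 1\) the CDF reduces to \(1 - e^{-\beta x}\), and `gamma_cdf + gamma_survival` should equal 1 up to floating-point error.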

entropy() → Array

Calculates the differential (Shannon) entropy, in nats.
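The differential entropy of a Gamma distribution has the closed form \(H = \alpha - \log \beta + \log \Gamma(\alpha) + (1 - \alpha)\,\psi(\alpha)\), where \(\psi\) is the digamma function. A minimal sketch of that formula, with a hypothetical helper name; it is not necessarily how pmrf computes it:

```python
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln


def gamma_entropy(concentration, rate):
    """Hypothetical helper: differential entropy (nats) of Gamma(alpha, beta)."""
    return (
        concentration
        - jnp.log(rate)
        + gammaln(concentration)
        + (1.0 - concentration) * digamma(concentration)
    )
```

Two hand-checkable cases: \(\alpha = 1\) gives the exponential entropy \(1 - \log \beta\), and \(\alpha = 2, \beta = 1\) gives \(2 - \psi(2) = 1 + \gamma_{\text{Euler}} \approx 1.5772\).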

icdf(value: Array) → Array

See Distribution.icdf.
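The Gamma inverse CDF has no closed form, so it must be computed numerically. One generic approach is bisection on \(P(\alpha, \beta x) = q\), sketched below with `jax.lax.fori_loop`. The helper name, the bracket heuristic (mean plus 50 standard deviations), and the iteration count are all assumptions; pmrf may use a different inversion method.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammainc


def gamma_icdf_bisect(q, concentration, rate, num_iters=60):
    """Hypothetical helper: solve P(alpha, beta * x) = q for x by bisection."""
    q = jnp.asarray(q, dtype=jnp.float32)
    concentration = jnp.asarray(concentration, dtype=jnp.float32)
    rate = jnp.asarray(rate, dtype=jnp.float32)

    # Heuristic bracket [0, mean + 50 * stddev]; assumed wide enough unless q is extremely close to 1.
    hi = (concentration + 50.0 * jnp.sqrt(concentration)) / rate + jnp.zeros_like(q)
    lo = jnp.zeros_like(hi)

    def step(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        # If the CDF at mid is still below q, the root lies in the upper half.
        below = gammainc(concentration, rate * mid) < q
        return jnp.where(below, mid, lo), jnp.where(below, hi, mid)

    lo, hi = jax.lax.fori_loop(0, num_iters, step, (lo, hi))
    return 0.5 * (lo + hi)
```

The median (`median()` above) is the special case \(q = 0.5\); for \(\alpha = 1\) it should come out close to \(\ln 2 / \beta\).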

kl_divergence(other_dist, *unused_args, **unused_kwargs) → Array

Calculates the KL divergence KL(self || other_dist) between two Gamma distributions.

Arguments:

  • other_dist: A Gamma distribution.

Returns:

  • KL(self || other_dist).
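Between two Gamma distributions \(p = \text{Gamma}(\alpha_p, \beta_p)\) (here `self`) and \(q = \text{Gamma}(\alpha_q, \beta_q)\) (here `other_dist`), the KL divergence has the standard closed form \((\alpha_p - \alpha_q)\psi(\alpha_p) - \log\Gamma(\alpha_p) + \log\Gamma(\alpha_q) + \alpha_q(\log\beta_p - \log\beta_q) + \alpha_p(\beta_q - \beta_p)/\beta_p\). The sketch below evaluates it on raw parameters; the function name and argument order are hypothetical, not pmrf's API.

```python
import jax.numpy as jnp
from jax.scipy.special import digamma, gammaln


def gamma_kl(alpha_p, rate_p, alpha_q, rate_q):
    """Hypothetical helper: KL(Gamma(alpha_p, rate_p) || Gamma(alpha_q, rate_q))."""
    return (
        (alpha_p - alpha_q) * digamma(alpha_p)
        - gammaln(alpha_p)
        + gammaln(alpha_q)
        + alpha_q * (jnp.log(rate_p) - jnp.log(rate_q))
        + alpha_p * (rate_q - rate_p) / rate_p
    )
```

With \(\alpha_p = \alpha_q = 1\) this reduces to the exponential case \(\log(\beta_p/\beta_q) + \beta_q/\beta_p - 1\), and it is zero when the two parameter sets coincide.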

log_cdf(value: Array) → Array

See Distribution.log_cdf.

log_prob(value: Array) → Array

See Distribution.log_prob.

mean() → Array

Calculates the mean.

median()

Calculates the median.

mode() → Array

Calculates the mode.
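For \(\alpha \ge 1\) the mode is \((\alpha - 1)/\beta\). For \(\alpha < 1\) the density is unbounded as \(x \to 0\), so there is no maximum on \(x > 0\); this page does not say what `mode()` returns in that case. The sketch below returns NaN there, which is one common convention and an assumption here; the helper name is hypothetical.

```python
import jax.numpy as jnp


def gamma_mode(concentration, rate):
    """Hypothetical helper: mode of Gamma(alpha, beta)."""
    concentration = jnp.asarray(concentration, dtype=jnp.float32)
    rate = jnp.asarray(rate, dtype=jnp.float32)
    mode = (concentration - 1.0) / rate
    # Assumption: NaN where alpha < 1, because the density has no maximum on x > 0 there.
    return jnp.where(concentration >= 1.0, mode, jnp.nan)
```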

sample(key: Key[jaxlib._jax.Array, '']) → Array

See Distribution.sample.
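`jax.random.gamma(key, a, shape)` draws from a Gamma distribution with shape `a` and unit rate, so dividing by \(\beta\) yields samples with rate \(\beta\). A minimal sketch of that approach with a hypothetical helper name; whether pmrf's `sample` does exactly this is not stated here.

```python
import jax
import jax.numpy as jnp


def gamma_sample(key, concentration, rate, shape=()):
    """Hypothetical helper: draw Gamma(alpha, beta) samples with the given shape."""
    # jax.random.gamma samples Gamma(alpha, 1); dividing by the rate rescales to Gamma(alpha, beta).
    return jax.random.gamma(key, concentration, shape=shape) / rate


key = jax.random.PRNGKey(0)
samples = gamma_sample(key, 3.0, 2.0, shape=(100_000,))
```

With \(\alpha = 3\) and \(\beta = 2\), the sample mean should be close to \(\alpha/\beta = 1.5\).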

stddev() → Array

Calculates the standard deviation.

variance() → Array

Calculates the variance.
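The mean, variance, and standard deviation listed above all have closed forms in terms of the concentration \(\alpha\) and rate \(\beta\): \(\alpha/\beta\), \(\alpha/\beta^2\), and \(\sqrt{\alpha}/\beta\). A minimal sketch with a hypothetical helper name:

```python
import jax.numpy as jnp


def gamma_moments(concentration, rate):
    """Hypothetical helper returning (mean, variance, stddev) of Gamma(alpha, beta)."""
    concentration = jnp.asarray(concentration, dtype=jnp.float32)
    rate = jnp.asarray(rate, dtype=jnp.float32)
    mean = concentration / rate               # alpha / beta
    variance = concentration / rate**2        # alpha / beta^2
    stddev = jnp.sqrt(concentration) / rate   # sqrt(alpha) / beta
    return mean, variance, stddev


mean, variance, stddev = gamma_moments(3.0, 2.0)
```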

concentration: Float[jaxlib._jax.Array, '...']

property event_shape: tuple

Shape of event of distribution samples.

rate: Float[jaxlib._jax.Array, '...']