AbstractDistribution

class pmrf.distributions.AbstractDistribution

Bases: AbstractStrictModule

Base class for all distreqx distributions.

abstractmethod cdf(value: PyTree[jax.jaxlib._jax.Array]) PyTree[jax.jaxlib._jax.Array]

Evaluates the cumulative distribution function at value.

Arguments:

  • value: An event.

Returns:

  • The CDF evaluated at value, i.e. P[X <= value].

cross_entropy(other_dist, **kwargs) Array

Calculates the cross entropy to another distribution.

Arguments:

  • other_dist: A compatible distreqx Distribution.

  • kwargs: Additional kwargs.

Returns:

  • The cross entropy H(self || other_dist).

abstractmethod entropy() PyTree[jax.jaxlib._jax.Array]

Calculates the Shannon entropy (in nats).

abstractmethod icdf(value: PyTree[jax.jaxlib._jax.Array]) PyTree[jax.jaxlib._jax.Array]

Evaluates the inverse cumulative distribution function at value.

For a given probability u, returns the value x such that P[X <= x] = u.

Arguments:

  • value: A probability value in [0, 1].

Returns:

  • The ICDF evaluated at value, i.e. x such that CDF(x) = value.
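The relationship between cdf and icdf can be checked with a round trip. The sketch below uses Python's standard-library statistics.NormalDist as a stand-in; it is not a pmrf or distreqx class, and its method is named inv_cdf rather than icdf.

```python
from statistics import NormalDist

# Standard-library stand-in for a distribution object (an assumption for
# illustration; a concrete AbstractDistribution subclass would expose
# cdf/icdf instead of cdf/inv_cdf).
dist = NormalDist(mu=0.0, sigma=1.0)

u = 0.975
x = dist.inv_cdf(u)    # ICDF: the x such that P[X <= x] = u
u_back = dist.cdf(x)   # CDF maps x back to the original probability

print(x, u_back)
```

For a continuous, strictly increasing CDF the round trip recovers u up to floating-point error.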

abstractmethod kl_divergence(other_dist, **kwargs) PyTree[jax.jaxlib._jax.Array]

Calculates the KL divergence to another distribution.

Arguments:

  • other_dist: A compatible distreqx Distribution.

  • kwargs: Additional kwargs.

Returns:

  • The KL divergence KL(self || other_dist).
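Under the usual definitions, the cross entropy, the Shannon entropy and the KL divergence are linked by H(p || q) = H(p) + KL(p || q). The sketch below checks that identity for two small discrete distributions using plain Python; the lists p and q are made-up example values, and whether cross_entropy is actually computed this way in the library is an assumption.

```python
import math

# Two example discrete distributions over the same three outcomes
# (illustrative values only).
p = [0.5, 0.25, 0.25]
q = [0.25, 0.25, 0.5]

# Shannon entropy of p, in nats.
entropy_p = -sum(pi * math.log(pi) for pi in p)

# Cross entropy H(p || q) and KL divergence KL(p || q).
cross_entropy_pq = -sum(pi * math.log(qi) for pi, qi in zip(p, q))
kl_pq = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

print(cross_entropy_pq, entropy_p + kl_pq)
```

Both printed values agree up to rounding, and the KL divergence is non-negative.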

abstractmethod log_cdf(value: PyTree[jax.jaxlib._jax.Array]) PyTree[jax.jaxlib._jax.Array]

Evaluates the log of the cumulative distribution function at value, i.e. log P[X <= value].

abstractmethod log_prob(value: PyTree[jax.jaxlib._jax.Array]) PyTree[jax.jaxlib._jax.Array]

Calculates the log probability of an event.

Arguments:

  • value: An event.

Returns:

  • The log probability log P(value).

abstractmethod log_survival_function(value: PyTree[jax.jaxlib._jax.Array]) PyTree[jax.jaxlib._jax.Array]

Evaluates the log of the survival function at value.

Note that the default definition of the log of the survival function is expressed in terms of the CDF and is not necessarily numerically stable. Subclasses should implement a more stable definition for distributions where one exists.

Arguments:

  • value: An event.

Returns:

  • The log of the survival function evaluated at value, i.e. log P[X > value].
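The stability concern can be illustrated with a standard normal far in the upper tail. The sketch below is plain Python, not library code: normal_cdf, log_sf_naive and log_sf_stable are hypothetical helper names, and the naive version is only assumed to resemble a CDF-based definition.

```python
import math

def normal_cdf(x):
    # CDF of the standard normal, written with the complementary error function.
    return 0.5 * math.erfc(-x / math.sqrt(2.0))

def log_sf_naive(x):
    # log(1 - CDF(x)): once CDF(x) rounds to 1.0 the difference is exactly 0.
    sf = 1.0 - normal_cdf(x)
    return math.log(sf) if sf > 0.0 else -math.inf

def log_sf_stable(x):
    # Evaluate the upper tail directly, so no cancellation occurs.
    return math.log(0.5 * math.erfc(x / math.sqrt(2.0)))

print(log_sf_naive(10.0), log_sf_stable(10.0))
```

At value = 10 the CDF-based form collapses to negative infinity, while the direct tail evaluation still returns a finite log probability. Near the centre of the distribution the two agree.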

abstractmethod mean() PyTree[jax.jaxlib._jax.Array]

Calculates the mean.

abstractmethod median() PyTree[jax.jaxlib._jax.Array]

Calculates the median.

abstractmethod mode() PyTree[jax.jaxlib._jax.Array]

Calculates the mode.

abstractmethod prob(value: PyTree[jax.jaxlib._jax.Array]) PyTree[jax.jaxlib._jax.Array]

Calculates the probability of an event.

Arguments:

  • value: An event.

Returns:

  • The probability P(value).

abstractmethod sample(key: Key[jaxlib._jax.Array, '']) PyTree[jax.jaxlib._jax.Array]

Samples an event.

abstractmethod sample_and_log_prob(key: Key[jaxlib._jax.Array, '']) tuple[PyTree[jax.jaxlib._jax.Array], PyTree[jax.jaxlib._jax.Array]]

Returns a sample and its log probability.

By default, it just calls log_prob on the generated samples. However, for many distributions it’s more efficient to compute the log prob of samples than of arbitrary events (for example, there’s no need to check that a sample is within the distribution’s domain). If that’s the case, a subclass may override this method with a more efficient implementation.

Arguments:

  • key: PRNG key.

Returns:

  • A tuple of a sample and its log probability.
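A minimal sketch of the fallback behaviour described above, written in plain Python. The Exponential class is a hypothetical toy, not part of pmrf or distreqx, and a random.Random instance stands in for the JAX PRNG key.

```python
import math
import random

class Exponential:
    """Hypothetical toy distribution used only for illustration."""

    def __init__(self, rate):
        self.rate = rate

    def sample(self, rng):
        u = rng.random()                     # uniform in [0, 1)
        return -math.log1p(-u) / self.rate   # inverse-CDF sampling

    def log_prob(self, value):
        if value < 0.0:                      # domain check for arbitrary events
            return -math.inf
        return math.log(self.rate) - self.rate * value

    def sample_and_log_prob(self, rng):
        # Generic fallback: draw a sample, then score it with log_prob.
        x = self.sample(rng)
        return x, self.log_prob(x)

dist = Exponential(rate=2.0)
x, lp = dist.sample_and_log_prob(random.Random(0))
print(x, lp)
```

A subclass that knows its samples are always inside the domain could skip the check in log_prob, which is the kind of saving the paragraph above refers to.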

abstractmethod stddev() PyTree[jax.jaxlib._jax.Array]

Calculates the standard deviation.

abstractmethod survival_function(value: PyTree[jax.jaxlib._jax.Array]) PyTree[jax.jaxlib._jax.Array]

Evaluates the survival function at value.

Note that the default definition of the survival function is expressed in terms of the CDF and is not necessarily numerically stable. Subclasses should implement a more stable definition for distributions where one exists.

Arguments:

  • value: An event.

Returns:

  • The survival function evaluated at value, i.e. P[X > value].

abstractmethod variance() PyTree[jax.jaxlib._jax.Array]

Calculates the variance.

property dtype: dtype

Data type of a sample.

abstract property event_shape: tuple[int] | PyTree[jax.ShapeDtypeStruct]

Shape of a single event sampled from the distribution.

property name: str

Distribution name.