Uniform

class pmrf.distributions.Uniform(low: Float[jaxlib._jax.Array, '...'], high: Float[jaxlib._jax.Array, '...'])

Bases: AbstractSTDDistribution, AbstractSurvivalDistribution

Continuous uniform distribution with lower bound low and upper bound high.

cdf(value: Array) -> Array

See Distribution.cdf.
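The uniform CDF is a linear ramp clamped to [0, 1]. The following plain-Python sketch of that formula works on scalar floats. The helper name `uniform_cdf` is hypothetical, and the code is not pmrf's implementation.

```python
def uniform_cdf(x: float, low: float, high: float) -> float:
    """CDF of Uniform(low, high) for scalar inputs (illustrative sketch)."""
    # Linear ramp (x - low) / (high - low), clamped so it is 0 below low and 1 above high.
    t = (x - low) / (high - low)
    return min(max(t, 0.0), 1.0)
```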

entropy() -> Array

See Distribution.entropy.
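Because the density is the constant 1/(high - low) on its support, the differential entropy reduces to log(high - low). The following stdlib sketch uses the hypothetical helper `uniform_entropy` to illustrate the math only.

```python
import math


def uniform_entropy(low: float, high: float) -> float:
    # -∫ p(x) log p(x) dx with p = 1/(high - low) on [low, high] gives log(high - low).
    return math.log(high - low)
```

Unlike discrete entropy, this value is negative whenever the interval is narrower than 1.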

icdf(value: Array) -> Array

See Distribution.icdf.
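The inverse CDF maps a probability p in [0, 1] to the quantile low + p * (high - low). The sketch below implements this with the hypothetical helper `uniform_icdf`. Returning NaN for p outside [0, 1] is an assumption of this sketch. This page does not say how pmrf handles such inputs.

```python
import math


def uniform_icdf(p: float, low: float, high: float) -> float:
    # Quantile function: affine map from [0, 1] onto [low, high].
    if not 0.0 <= p <= 1.0:
        return math.nan  # assumption: out-of-range probabilities yield NaN in this sketch
    return low + p * (high - low)
```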

kl_divergence(other_dist, **kwargs) -> Array

Calculates the KL divergence to another distribution.

Arguments:

  • other_dist: A compatible distreqx distribution.

  • kwargs: Additional kwargs.

Returns:

The KL divergence KL(self || other_dist).
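When both distributions are uniform, KL(U(a, b) || U(c, d)) has a closed form. It equals log((d - c) / (b - a)) if [a, b] lies inside [c, d], and it is infinite otherwise. The following stdlib sketch uses a hypothetical helper. It is an assumption that pmrf's kl_divergence applies exactly this formula when other_dist is a Uniform.

```python
import math


def uniform_kl(low1: float, high1: float, low2: float, high2: float) -> float:
    """KL(U(low1, high1) || U(low2, high2)) for scalar bounds (illustrative sketch)."""
    # Finite only if the first support is contained in the second; otherwise
    # self assigns mass where other_dist has zero density.
    if low2 <= low1 and high1 <= high2:
        return math.log(high2 - low2) - math.log(high1 - low1)
    return math.inf
```

Note the asymmetry: widening other_dist costs only a log factor, but letting self extend past other_dist's support makes the divergence infinite.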

log_cdf(value: Array) -> Array

See Distribution.log_cdf.
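The log-CDF is the logarithm of the clamped ramp, so it is -inf at or below low and 0 at or above high. The hypothetical `uniform_log_cdf` below is a scalar stdlib sketch of that behavior and is not pmrf's code.

```python
import math


def uniform_log_cdf(x: float, low: float, high: float) -> float:
    # F(x) = 0 for x <= low, so log F(x) is -inf there.
    if x <= low:
        return -math.inf
    # Clamp at 1 so that log F(x) = 0 for x >= high.
    return math.log(min((x - low) / (high - low), 1.0))
```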

log_prob(value: Array) -> Array

See Distribution.log_prob.
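Inside the support, the log density is the constant -log(high - low), and outside it is -inf. The hypothetical `uniform_log_prob` below sketches this. Treating the endpoints as inside the support is an assumption of the sketch.

```python
import math


def uniform_log_prob(x: float, low: float, high: float) -> float:
    # Density 1/(high - low) on [low, high] (endpoints included here by assumption), 0 elsewhere.
    if low <= x <= high:
        return -math.log(high - low)
    return -math.inf
```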

mean() -> Array

See Distribution.mean.
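The mean is the midpoint (low + high) / 2. Since low and high may carry batch dimensions ('...'), the following sketch computes the mean elementwise. Parallel Python lists stand in for batched arrays, and the helper `uniform_mean` is hypothetical.

```python
def uniform_mean(low: list[float], high: list[float]) -> list[float]:
    # Elementwise midpoint of each (low, high) pair.
    return [(lo + hi) / 2 for lo, hi in zip(low, high)]
```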

median() -> Array

See Distribution.median.
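For a uniform distribution, the median coincides with the mean because the CDF rises linearly and crosses 1/2 at the midpoint. The hypothetical helper below makes that reasoning explicit.

```python
def uniform_median(low: float, high: float) -> float:
    # Solve (m - low) / (high - low) = 1/2 for m: the CDF hits 1/2 at the midpoint.
    return low + 0.5 * (high - low)
```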

mode() -> Array

See Distribution.mode.

prob(value: Array) -> Array

See Distribution.prob.

sample(key: Key[jaxlib._jax.Array, '']) -> Array

See Distribution.sample.
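A uniform draw on [low, high] can be obtained by rescaling a standard uniform variate u with low + (high - low) * u. pmrf's sample takes a JAX PRNG key. In the following stdlib sketch, a seeded `random.Random` instance stands in for that key, and the helper `uniform_sample` is hypothetical.

```python
import random


def uniform_sample(rng: random.Random, low: float, high: float, n: int) -> list[float]:
    # rng.random() returns u in [0, 1); the affine map moves it onto the interval from low to high.
    return [low + (high - low) * rng.random() for _ in range(n)]
```

For example, `uniform_sample(random.Random(0), 2.0, 5.0, 1000)` yields 1000 values whose average lies close to the mean 3.5.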

sample_and_log_prob(key: Key[jaxlib._jax.Array, '']) -> tuple[Array, Array]

Returns a sample and its log prob.

By default, it just calls log_prob on the generated samples. However, for many distributions it’s more efficient to compute the log prob of samples than of arbitrary events (for example, there’s no need to check that a sample is within the distribution’s domain). If that’s the case, a subclass may override this method with a more efficient implementation.

Arguments:

  • key: PRNG key.

Returns:

  • A tuple of a sample and its log prob.
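The uniform distribution is a clear case of the shortcut described above. A draw lies in the support by construction, so its log probability is the constant -log(high - low) and needs no domain check. The following stdlib sketch uses a hypothetical helper to show this shortcut. This page does not state whether pmrf's Uniform actually overrides sample_and_log_prob.

```python
import math
import random


def uniform_sample_and_log_prob(rng: random.Random, low: float, high: float) -> tuple[float, float]:
    # The draw is inside [low, high] by construction...
    x = low + (high - low) * rng.random()
    # ...so its log density is the constant -log(high - low), with no support check.
    return x, -math.log(high - low)
```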

variance() -> Array

See Distribution.variance.
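The variance of Uniform(low, high) is (high - low)^2 / 12, which follows from E[X^2] - E[X]^2. The following stdlib sketch uses the hypothetical helper `uniform_variance`.

```python
def uniform_variance(low: float, high: float) -> float:
    # E[X^2] - E[X]^2 = (low^2 + low*high + high^2)/3 - ((low + high)/2)^2 = (high - low)^2 / 12.
    return (high - low) ** 2 / 12
```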

property event_shape: tuple[int, ...]

Shape of event of distribution samples.

high: Float[jaxlib._jax.Array, '...']

Upper bound of the support.

low: Float[jaxlib._jax.Array, '...']

Lower bound of the support.

property range: Array

Width of the support, high - low.