
Probability

parax.probability.AbstractProbabilistic

Bases: AbstractConstrained[T]

The abstract interface for a probabilistic PyTree.

Probabilistic PyTrees have an associated probability distribution whose event shape matches the PyTree structure of self.

parax.is_probabilistic uses this class as its type check.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| distribution | AbstractVar[AbstractDistribution] | The probability distribution associated with this PyTree node. |
| constraint | AbstractVar[AbstractConstraint] | The active constraint of the PyTree. |
| bounds | AbstractVar[tuple[T, T]] | The current bounds of the PyTree. Each element of the tuple must have the same PyTree structure as self. |
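A minimal sketch of this contract in plain JAX, without parax. ToyNormal, ToyProbabilistic and is_probabilistic_toy are hypothetical stand-ins: the node carries a distribution whose event shape matches its value, and an isinstance check (assumed here to resemble parax.is_probabilistic) passed as is_leaf makes JAX treat the node as a single leaf instead of descending into its arrays.

```python
import jax
import jax.numpy as jnp


class ToyNormal:
    """Hypothetical stand-in for a distribution object (not a distreqx class)."""

    def __init__(self, loc, scale):
        self.loc = jnp.asarray(loc)
        self.scale = jnp.asarray(scale)

    @property
    def event_shape(self):
        return self.loc.shape


@jax.tree_util.register_pytree_node_class
class ToyProbabilistic:
    """Hypothetical stand-in for an AbstractProbabilistic node."""

    def __init__(self, value, distribution):
        self.value = value
        self.distribution = distribution

    def tree_flatten(self):
        # `value` is the only dynamic leaf; the distribution rides along as aux data.
        return (self.value,), self.distribution

    @classmethod
    def tree_unflatten(cls, distribution, children):
        return cls(children[0], distribution)


def is_probabilistic_toy(x):
    # Assumed to mirror parax.is_probabilistic: a plain isinstance check.
    return isinstance(x, ToyProbabilistic)


node = ToyProbabilistic(jnp.zeros(3), ToyNormal(jnp.zeros(3), jnp.ones(3)))
tree = {"w": node, "b": jnp.ones(2)}

# The distribution's event shape matches the shape of the value it describes.
assert node.distribution.event_shape == node.value.shape

# Without is_leaf, JAX descends into the node and sees only raw arrays.
plain_leaves = jax.tree_util.tree_leaves(tree)
# With is_leaf, traversal stops at the probabilistic node itself.
prob_leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_probabilistic_toy)
```

Stopping traversal at the node is what lets a function such as tree_distributions read the node's distribution attribute rather than the arrays inside it.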

parax.probability.tree_distributions(tree)

Extracts the individual probability distributions of a PyTree.

Inexact (floating-point or complex) arrays default to distreqx.distributions.ImproperUniform with the same shape as the array.

This function raises a ValueError for any leaf that is neither an inexact array nor an instance of parax.probability.AbstractProbabilistic; integer arrays are rejected too. Mark such leaves as static, or filter them out first, e.g. with eqx.filter.
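The leaf-by-leaf dispatch and the filtering advice can be sketched in plain JAX. This covers only the array fallback and the error branch (probabilistic nodes are omitted). ToyImproperUniform, is_inexact_array_toy and tree_distributions_toy are hypothetical stand-ins, and is_inexact_array_toy is only an assumed approximation of eqx.is_inexact_array. The filtering step replaces rejected leaves with None, which is what eqx.filter does by default; JAX treats None as an empty subtree, so the mapped function never sees it.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ToyImproperUniform:
    """Hypothetical stand-in for distreqx.distributions.ImproperUniform."""

    def __init__(self, shape):
        self.shape = tuple(shape)


def is_inexact_array_toy(x):
    # Assumed approximation of eqx.is_inexact_array: an array with a float or complex dtype.
    return isinstance(x, (jax.Array, np.ndarray)) and jnp.issubdtype(x.dtype, jnp.inexact)


def tree_distributions_toy(tree):
    # Mirrors only the array fallback and the error branch of tree_distributions.
    def _get_distribution(path, x):
        if is_inexact_array_toy(x):
            return ToyImproperUniform(jnp.shape(x))
        raise ValueError(
            f"Leaf of type {type(x).__name__} at {jax.tree_util.keystr(path)} "
            "is neither probabilistic nor an inexact array"
        )

    return jax.tree_util.tree_map_with_path(_get_distribution, tree)


params = {"w": jnp.zeros((2, 3)), "n": jnp.array(4)}  # "n" holds an integer array

message = None
try:
    tree_distributions_toy(params)
except ValueError as err:
    message = str(err)  # the integer leaf is rejected

# Replace non-inexact leaves with None, as eqx.filter does by default.
filtered = jax.tree_util.tree_map(lambda x: x if is_inexact_array_toy(x) else None, params)
dists = tree_distributions_toy(filtered)
```

The result keeps the input's structure: the filtered-out entry stays None and the float array maps to a flat prior of shape (2, 3).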

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | PyTree | The PyTree containing probabilistic nodes or standard arrays. | required |

Returns:

| Type | Description |
| --- | --- |
| PyTree | A PyTree of the exact same structure containing the extracted probability distributions. |

Source code in parax/probability.py
def tree_distributions(tree: PyTree) -> PyTree:
    """
    Extracts the individual probability distributions of a PyTree.

    Inexact (floating-point) arrays default to `distreqx.distributions.ImproperUniform`.

    This function raises a `ValueError` for any leaf that is neither an
    inexact array nor an instance of `parax.probability.AbstractProbabilistic`.
    Mark such leaves as static or filter them out first, e.g. with `eqx.filter`.

    Args:
        tree: The PyTree containing probabilistic nodes or standard arrays.

    Returns:
        A PyTree of the exact same structure containing the extracted 
        probability distributions.
    """
    from distreqx.distributions import ImproperUniform
    from parax.wrappers import unwrap

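    def _get_distribution(path, x):
        # Probabilistic nodes supply their own distribution (passed through `unwrap`).
        if is_probabilistic(x):
            return unwrap(x.distribution)
        # Plain floating-point arrays fall back to an improper flat distribution of the same shape.
        if eqx.is_inexact_array(x):
            return ImproperUniform(shape=jnp.shape(x))
        # Anything else, including integer arrays, is rejected.
        raise ValueError(f"Found a leaf node of type {type(x)} that is neither probabilistic nor an array in `parax.probability.tree_distributions`. Value: {x}, path: {path}")

    # `is_leaf` stops traversal at probabilistic nodes so each is handled as a whole.
    distributions = jax.tree.map_with_path(_get_distribution, tree, is_leaf=is_probabilistic)
    return distributions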

parax.probability.tree_joint_distribution(tree)

Extracts a single joint probability distribution from a PyTree.

Wraps the output of parax.probability.tree_distributions in a distreqx.distributions.Joint distribution, yielding one distribution whose event shape matches the PyTree structure of tree.

Note that this distribution is defined over the constrained space of any parax.probability.AbstractProbabilistic variables. For interoperability with algorithms that operate in unconstrained space, see parax.probability.tree_unconstrained_distribution.
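If Joint treats its components as independent (an assumption; the text above only says it wraps the PyTree of distributions), its log-density at a tree-shaped point is the sum of the per-leaf log-densities. A plain-JAX sketch with hypothetical stand-ins ToyNormal and ToyImproperUniform, where the flat component is taken here to contribute a log-density of 0:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class ToyNormal:
    """Hypothetical elementwise Normal; log_prob sums over its event dimensions."""

    def __init__(self, loc, scale):
        self.loc, self.scale = jnp.asarray(loc), jnp.asarray(scale)

    def log_prob(self, x):
        return jnp.sum(norm.logpdf(x, self.loc, self.scale))


class ToyImproperUniform:
    """Hypothetical improper flat distribution; assumed to contribute log-density 0."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    def log_prob(self, x):
        return jnp.zeros(())


def is_dist(x):
    # Real distributions may themselves be PyTrees, so stop traversal at them.
    return isinstance(x, (ToyNormal, ToyImproperUniform))


def joint_log_prob(dists, value):
    # Pair each distribution leaf with the matching value leaf, then sum the results.
    per_leaf = jax.tree_util.tree_map(lambda d, v: d.log_prob(v), dists, value, is_leaf=is_dist)
    return sum(jax.tree_util.tree_leaves(per_leaf))


dists = {"mu": ToyNormal(0.0, 1.0), "z": ToyImproperUniform((2,))}
value = {"mu": jnp.array(0.0), "z": jnp.array([5.0, -3.0])}
lp = joint_log_prob(dists, value)
```

Under these assumptions the flat leaf leaves the total unchanged, so lp equals the standard normal log-density at 0 regardless of the value of z.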

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |

Returns:

| Type | Description |
| --- | --- |
| AbstractDistribution | A single joint distribution whose event shape matches the structure of tree. |

Source code in parax/probability.py
def tree_joint_distribution(tree: PyTree) -> AbstractDistribution:
    """
    Extracts a single joint probability distribution from a PyTree.

    Wraps the output of `parax.probability.tree_distributions`
    in a `distreqx.distributions.Joint` distribution to define
    a single distribution that matches the shape of `tree`.

    Note that this distribution is defined over the constrained space
    of any `parax.probability.AbstractProbabilistic` variables.
    For interoperability with unconstrained algorithms, see
    `parax.probability.tree_unconstrained_distribution`.

    Args:
        tree: The PyTree model containing probabilistic nodes or standard arrays.

    Returns:
        A single joint distribution whose event shape matches the structure of `tree`.
    """
    from distreqx.distributions import Joint
    return Joint(tree_distributions(tree))
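The constrained-space caveat on tree_joint_distribution matters because a density does not carry over unchanged to a reparameterized variable. By the change-of-variables rule, if x = exp(u) then log p_U(u) = log p_X(exp(u)) + u, where the last term is log|dx/du|. The sketch below checks this numerically for an Exponential(1) density on x > 0. It only illustrates the rule; it is not a description of how parax.probability.tree_unconstrained_distribution is implemented.

```python
import jax.numpy as jnp


def log_p_x(x):
    # Exponential(1) log-density on the constrained space x > 0.
    return -x


def log_p_u(u):
    # Change of variables x = exp(u): add log|dx/du| = u.
    return log_p_x(jnp.exp(u)) + u


u = jnp.linspace(-20.0, 5.0, 20001)
du = u[1] - u[0]


def trapezoid(f):
    # Trapezoidal rule on the uniform grid `u`.
    return du * (jnp.sum(f) - 0.5 * (f[0] + f[-1]))


mass_corrected = trapezoid(jnp.exp(log_p_u(u)))       # close to 1
mass_naive = trapezoid(jnp.exp(log_p_x(jnp.exp(u))))  # grows as the grid's lower limit decreases
```

Dropping the Jacobian term yields a function of u that is not a normalized density, which is why algorithms working in unconstrained space need a separately defined distribution.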