Probabilistic

parax.probabilistic.AbstractProbabilistic

Bases: Module, Generic[Base]

The abstract interface for a probabilistic PyTree.

A probabilistic PyTree has a probability distribution associated with it: samples drawn from that distribution should match the PyTree structure of self.

The interface relies on the concept of a "base" space: the representation of the PyTree in which inference algorithms operate.

parax.is_probabilistic uses this class as its type check.
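To illustrate why a separate base space is useful, consider a parameter constrained to (0, 1) and stored as an unconstrained real number. The logit/sigmoid pair below is an assumption chosen for the example, since this page does not say which transforms parax uses. With that mapping, any step an inference algorithm takes in base space maps back to a valid value.

```python
import math


# Hypothetical bijection between a constrained value p in (0, 1)
# and an unconstrained base value z on the real line.
def to_base(p: float) -> float:
    return math.log(p) - math.log1p(-p)  # logit


def from_base(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))  # sigmoid


p = 0.9
z = to_base(p)
# Adding 10.0 directly to p would give 10.9, outside (0, 1).
# The same step taken in base space still maps back to a valid value.
p_new = from_base(z + 10.0)
print(0.0 < p_new < 1.0)  # True
```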

Attributes:

| Name | Type | Description |
|---|---|---|
| distribution | AbstractVar[AbstractDistribution] | The probability distribution associated with this PyTree node. |

base abstractmethod property

Returns the current PyTree in the probability base space.

update(base) abstractmethod

Returns a new instance of this object updated with a new base PyTree.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| base | Base | The new base-space PyTree representing the sampled state. | required |

Returns:

| Type | Description |
|---|---|
| AbstractProbabilistic | A new instance of the probabilistic object, updated to reflect the new base. |

Source code in parax/probabilistic.py
@abstractmethod
def update(self, base: Base) -> "AbstractProbabilistic":
    """
    Returns a new instance of this object updated with a new base PyTree.

    Args:
        base: The new base-space PyTree representing the sampled state.

    Returns:
        A new instance of the probabilistic object, updated to reflect the new base.
    """
    pass
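The contract of `base` and `update` can be illustrated without parax or Equinox. In the sketch below, `ProbabilisticSketch` is a hypothetical stand-in for `AbstractProbabilistic`, and `PositiveScale` is an invented subclass that stores a positive value and exposes its logarithm as the base. The log transform and the frozen dataclass, which stands in for an immutable Equinox module, are assumptions made for the example.

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace
from typing import Any, Generic, TypeVar

# Hypothetical type variable standing in for the `Base` PyTree type.
Base = TypeVar("Base")


class ProbabilisticSketch(ABC, Generic[Base]):
    """Hypothetical stand-in for AbstractProbabilistic (no parax or Equinox)."""

    distribution: Any  # declared only; concrete subclasses provide a value

    @property
    @abstractmethod
    def base(self) -> Base:
        """Current state in base space."""

    @abstractmethod
    def update(self, base: Base) -> "ProbabilisticSketch[Base]":
        """Return a new instance built from a new base value."""


@dataclass(frozen=True)
class PositiveScale(ProbabilisticSketch[float]):
    """Hypothetical positive parameter whose base space is log(value)."""

    value: float
    distribution: Any = None  # placeholder; a real node would hold a distribution

    @property
    def base(self) -> float:
        return math.log(self.value)

    def update(self, base: float) -> "PositiveScale":
        # Map back from base space and return a new, immutable instance.
        return replace(self, value=math.exp(base))


scale = PositiveScale(2.0)
new_scale = scale.update(scale.base + 1.0)  # a step taken in base space
print(scale.value)  # 2.0 (the original is unchanged)
print(new_scale.value > 0)  # True
```

Returning a new instance rather than mutating `self` matches the description of `update` above and fits JAX's functional style, where PyTrees are treated as immutable values.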

parax.probabilistic.tree_distribution(tree)

Extracts the probability distributions of a PyTree.

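Inexact (floating-point or complex) arrays that are not inside a probabilistic node default to distreqx.ImproperUniform; all other leaves are returned unchanged.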

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |

Returns:

| Type | Description |
|---|---|
| PyTree | A PyTree of the exact same structure containing the extracted probability distributions. |

Source code in parax/probabilistic.py
def tree_distribution(tree: PyTree) -> PyTree:
    """
    Extracts the probability distributions of a PyTree.

    Standard arrays default to `distreqx.ImproperUniform`.

    Args:
        tree: The PyTree model containing probabilistic nodes or standard arrays.

    Returns:
        A PyTree of the exact same structure containing the extracted 
        probability distributions.
    """
    from parax.filters import is_probabilistic

    def _get_distribution(x):
        if is_probabilistic(x):
            return unwrap(x.distribution)
        if eqx.is_inexact_array(x):
            from distreqx.distributions import ImproperUniform
            return ImproperUniform(shape=jnp.shape(x))
        return x

    distributions = jax.tree_util.tree_map(_get_distribution, tree, is_leaf=is_probabilistic)
    return distributions
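The dispatch rule in `tree_distribution` can be reproduced with plain JAX. In the sketch below, `ProbNode`, `NormalStub`, and `ImproperUniformStub` are hypothetical stand-ins for a parax probabilistic node and for distreqx distributions, and `is_inexact_array` only approximates `equinox.is_inexact_array`. The sketch shows three behaviours: probabilistic nodes yield their distribution, float arrays get an improper-uniform placeholder, and other leaves pass through unchanged.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass(frozen=True)
class ImproperUniformStub:
    """Hypothetical placeholder for distreqx's ImproperUniform."""
    shape: tuple


@dataclasses.dataclass(frozen=True)
class NormalStub:
    """Hypothetical placeholder for a real distribution object."""
    loc: float
    scale: float


@jax.tree_util.register_pytree_node_class
class ProbNode:
    """Hypothetical probabilistic node: a PyTree with a value and a distribution."""

    def __init__(self, value, distribution):
        self.value = value
        self.distribution = distribution

    def tree_flatten(self):
        # The value is a child; the (hashable) distribution is static aux data.
        return (self.value,), self.distribution

    @classmethod
    def tree_unflatten(cls, distribution, children):
        return cls(children[0], distribution)


def is_prob(x):
    return isinstance(x, ProbNode)


def is_inexact_array(x):
    # Roughly mirrors equinox.is_inexact_array: a float/complex JAX or NumPy array.
    return isinstance(x, (jax.Array, np.ndarray)) and jnp.issubdtype(x.dtype, jnp.inexact)


def tree_distribution_sketch(tree):
    def get(x):
        if is_prob(x):
            return x.distribution
        if is_inexact_array(x):
            return ImproperUniformStub(shape=jnp.shape(x))
        return x  # integer arrays, Python scalars, etc. pass through unchanged

    # is_leaf stops tree_map from descending into probabilistic nodes.
    return jax.tree_util.tree_map(get, tree, is_leaf=is_prob)


model = {
    "weight": jnp.ones((2, 3)),
    "scale": ProbNode(jnp.array(1.0), NormalStub(0.0, 1.0)),
    "n_steps": 10,
}
dists = tree_distribution_sketch(model)
# dists["weight"]  -> ImproperUniformStub(shape=(2, 3))
# dists["scale"]   -> NormalStub(loc=0.0, scale=1.0)
# dists["n_steps"] -> 10
```

Without `is_leaf`, `tree_map` would descend into `ProbNode` and map its float `value` to an improper uniform, losing the node's own distribution. This is why the original passes `is_leaf=is_probabilistic`.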

parax.probabilistic.tree_joint(tree)

Extracts a single joint probability distribution from a PyTree.

Wraps the output of parax.probabilistic.tree_distribution in a distreqx.distributions.Joint distribution, defining a single distribution whose samples match the structure of tree.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |

Returns:

| Type | Description |
|---|---|
| Joint | A single joint distribution whose samples match the structure of tree. |

Source code in parax/probabilistic.py
def tree_joint(tree: PyTree) -> Joint:
    """
    Extracts a single joint probability distribution from a PyTree.

    Wraps the output of `parax.probabilistic.tree_distribution`
    in a `distreqx.distributions.Joint` distribution to define
    a single distribution that matches the shape of `tree`.

    Args:
        tree: The PyTree model containing probabilistic nodes or standard arrays.

    Returns:
        A single joint distribution whose samples match the structure of `tree`.
    """
    return Joint(tree_distribution(tree))
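A joint distribution over a PyTree of independent distributions samples each leaf with its own PRNG key, then rebuilds the tree. It scores a sample by summing the per-leaf log-densities. The sketch below illustrates that idea with invented names (`NormalStub`, `JointSketch`) and assumes independent leaves. It is not distreqx's `Joint` implementation, whose API may differ.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


@dataclasses.dataclass(frozen=True)
class NormalStub:
    """Hypothetical normal distribution used as a PyTree leaf."""
    loc: float
    scale: float
    shape: tuple = ()

    def sample(self, key):
        return self.loc + self.scale * jax.random.normal(key, self.shape)

    def log_prob(self, x):
        # Sum over elements so each leaf contributes a scalar.
        return jnp.sum(norm.logpdf(x, self.loc, self.scale))


def _is_dist(x):
    return isinstance(x, NormalStub)


class JointSketch:
    """Hypothetical joint over a PyTree of *independent* distributions."""

    def __init__(self, dists):
        self.leaves, self.treedef = jax.tree_util.tree_flatten(dists, is_leaf=_is_dist)

    def sample(self, key):
        # One independent PRNG key per leaf distribution.
        keys = jax.random.split(key, len(self.leaves))
        samples = [d.sample(k) for d, k in zip(self.leaves, keys)]
        # Rebuild the original structure so samples match the PyTree.
        return jax.tree_util.tree_unflatten(self.treedef, samples)

    def log_prob(self, value):
        # Assumes `value` has the same structure as the distributions.
        values = jax.tree_util.tree_leaves(value)
        return sum(d.log_prob(v) for d, v in zip(self.leaves, values))


dists = {"w": NormalStub(0.0, 1.0, (2,)), "b": NormalStub(1.0, 0.5)}
joint = JointSketch(dists)
x = joint.sample(jax.random.PRNGKey(0))  # a dict with keys "b" and "w"
lp = joint.log_prob(x)  # scalar: sum of the per-leaf log-densities
```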