
Probability

parax.probability.AbstractProbabilistic

Bases: AbstractConstrained[T]

The abstract interface for a probabilistic PyTree.

Probabilistic PyTrees have a probability distribution associated with them. That is, the event shape of the distribution matches the PyTree structure of self.

Used as a type check for parax.is_probabilistic.

Attributes:

distribution (AbstractVar[AbstractDistribution]): The probability distribution associated with this PyTree node.

constraint (AbstractVar[AbstractConstraint]): The active constraint of the PyTree.

bounds (AbstractVar[tuple[T, T]]): The current PyTree bounds. Each bound must have the same PyTree structure as self.
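The following is a minimal sketch of what a probabilistic PyTree node might look like. It does not use parax or distreqx. Normal, ProbParam and is_probabilistic are hypothetical stand-ins, and the sketch assumes only that a probabilistic node pairs an array value with a distribution whose event shape matches that value.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Normal:
    """Hypothetical stand-in for an AbstractDistribution (not the distreqx class)."""

    loc: float
    scale: float
    event_shape: tuple

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        # Sum the elementwise log-densities over the event dimensions.
        return jnp.sum(-0.5 * z**2 - jnp.log(self.scale) - 0.5 * jnp.log(2 * jnp.pi))


@jax.tree_util.register_pytree_node_class
class ProbParam:
    """Hypothetical probabilistic node: an array value plus its distribution."""

    def __init__(self, value, distribution):
        self.value = value
        self.distribution = distribution

    def tree_flatten(self):
        # The array is the only child; the distribution travels as static aux data.
        return (self.value,), self.distribution

    @classmethod
    def tree_unflatten(cls, distribution, children):
        return cls(children[0], distribution)


def is_probabilistic(x):
    # Hypothetical analogue of parax.is_probabilistic: a plain isinstance check.
    return isinstance(x, ProbParam)


w = ProbParam(jnp.zeros(3), Normal(0.0, 1.0, event_shape=(3,)))
print(is_probabilistic(w), is_probabilistic(jnp.zeros(3)))
print(w.distribution.event_shape == jnp.shape(w.value))
print(w.distribution.log_prob(w.value))
```

In this sketch the distribution is stored as aux data, so JAX sees only the array as a leaf. The real parax classes may organize their fields differently.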

parax.probability.tree_distributions(tree)

Extracts the individual probability distributions of a PyTree.

Plain inexact (floating-point) arrays default to distreqx.distributions.ImproperUniform.

Note that this function rejects any leaf that is neither an inexact array nor a parax.probability.AbstractProbabilistic node (integer arrays included) and raises a ValueError. If your tree contains such leaves, mark them as static or filter them out first, e.g. with eqx.filter.

Parameters:

tree (PyTree, required): The PyTree containing probabilistic nodes or standard arrays.

Returns:

PyTree: A PyTree with exactly the same structure as tree, containing the extracted probability distributions.

Source code in parax/probability.py
def tree_distributions(tree: PyTree) -> PyTree:
    """
    Extracts the individual probability distributions of a PyTree.

    Standard arrays default to `distreqx.distributions.ImproperUniform`.

    Note that this function does not allow non-array/probabilistic leaf nodes.
    If you have leaves in your tree that are neither arrays nor derive
    from `parax.probability.AbstractProbabilistic`, be sure to mark
    them as static or filter them out using e.g. `eqx.filter` first.

    Args:
        tree: The PyTree containing probabilistic nodes or standard arrays.

    Returns:
        A PyTree of the exact same structure containing the extracted 
        probability distributions.
    """
    from distreqx.distributions import ImproperUniform
    from parax.wrappers import unwrap

    def _get_distribution(path, x):
        if is_probabilistic(x):
            return unwrap(x.distribution)
        if eqx.is_inexact_array(x):
            return ImproperUniform(shape=jnp.shape(x))
        raise ValueError(
            f"Found a leaf node of type {type(x)} that is neither probabilistic "
            "nor an array in `parax.probability.tree_distributions`. "
            f"Value: {x}, path: {path}"
        )

    distributions = jax.tree.map_with_path(_get_distribution, tree, is_leaf=is_probabilistic)
    return distributions
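The traversal in tree_distributions can be reproduced without parax or distreqx. In this sketch, ImproperUniformStub, ProbLeaf and tree_distributions_sketch are hypothetical stand-ins, and is_inexact_array is a simplified version of eqx.is_inexact_array that accepts only JAX arrays. It illustrates the three outcomes of _get_distribution: a probabilistic node yields its own distribution, a floating-point array gets an improper-uniform placeholder, and anything else (including an integer array) raises ValueError.

```python
import jax
import jax.numpy as jnp


class ImproperUniformStub:
    """Hypothetical stand-in for distreqx's ImproperUniform; only records a shape."""

    def __init__(self, shape):
        self.shape = shape


class ProbLeaf:
    """Hypothetical probabilistic node carrying its own distribution object."""

    def __init__(self, value, distribution):
        self.value = value
        self.distribution = distribution


def is_probabilistic(x):
    return isinstance(x, ProbLeaf)


def is_inexact_array(x):
    # Simplified eqx.is_inexact_array: a JAX array with a floating or complex dtype.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def tree_distributions_sketch(tree):
    def _get_distribution(path, x):
        if is_probabilistic(x):
            return x.distribution
        if is_inexact_array(x):
            return ImproperUniformStub(shape=jnp.shape(x))
        raise ValueError(f"Unsupported leaf {type(x)} at {jax.tree_util.keystr(path)}")

    # is_leaf mirrors the parax source; ProbLeaf is unregistered, so it is a leaf anyway.
    return jax.tree_util.tree_map_with_path(
        _get_distribution, tree, is_leaf=is_probabilistic
    )


tree = {"w": ProbLeaf(jnp.ones(2), "normal-prior"), "b": jnp.zeros(3)}
dists = tree_distributions_sketch(tree)
print(dists["w"], dists["b"].shape)
```

Calling tree_distributions_sketch({"n": jnp.arange(3)}) raises ValueError, because an int32 array is not inexact.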

parax.probability.tree_joint_distribution(tree)

Extracts the single joint probability distribution of a PyTree.

Wraps the output of parax.probability.tree_distributions in a distreqx.distributions.Joint distribution to define a single distribution that matches the shape of tree.

Note that this distribution is defined over the constrained space of any parax.probability.AbstractProbabilistic variables. For interoperability with unconstrained algorithms, see parax.probability.tree_unconstrained_distribution.

Parameters:

tree (PyTree, required): The PyTree model containing probabilistic nodes or standard arrays.

Returns:

AbstractDistribution: A single joint distribution whose event shape matches the structure of tree.

Source code in parax/probability.py
def tree_joint_distribution(tree: PyTree) -> AbstractDistribution:
    """
    Extracts the single joint probability distribution of a PyTree.

    Wraps the output of `parax.probability.tree_distributions`
    in a `distreqx.distributions.Joint` distribution to define
    a single distribution that matches the shape of `tree`.

    Note that this distribution is defined over the constrained space
    of any `parax.probability.AbstractProbabilistic` variables.
    For interoperability with unconstrained algorithms, see
    `parax.probability.tree_unconstrained_distribution`.

    Args:
        tree: The PyTree model containing probabilistic nodes or standard arrays.

    Returns:
        A single joint distribution whose event shape matches the structure of `tree`.
    """
    from distreqx.distributions import Joint
    return Joint(tree_distributions(tree))
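For independent components, the log-density of a joint distribution is the sum of the component log-densities. This sketch assumes that is how the distreqx Joint wrapper combines the per-leaf distributions (check the distreqx documentation to confirm). joint_log_prob is a hypothetical helper, jax.scipy.stats.norm.logpdf plays the role of a proper prior, and the improper flat prior is assumed to contribute a log-density of 0.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def joint_log_prob(log_prob_fns, values):
    """Hypothetical helper: sum independent per-leaf log-densities over matching PyTrees."""
    # Each leaf's log-density is first summed over its own event dimensions.
    per_leaf = jax.tree_util.tree_map(lambda f, v: jnp.sum(f(v)), log_prob_fns, values)
    # The joint log-density of independent components is the sum over leaves.
    return sum(jax.tree_util.tree_leaves(per_leaf))


log_prob_fns = {
    "w": norm.logpdf,                  # standard normal prior on w
    "b": lambda x: jnp.zeros_like(x),  # flat (improper) prior on b, log-density 0
}
values = {"w": jnp.zeros(2), "b": jnp.ones(3)}
print(joint_log_prob(log_prob_fns, values))
```

Because the log-density tree has the same structure as the value tree, the per-leaf densities can be paired with tree_map before any summation.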

parax.probability.is_dynamic(x)

Identifies parameters that should be updated during probabilistic inference.

This function acts as the primary filter for eqx.partition, determining which nodes are routed to the dynamic (differentiable/optimizable) tree and which are left behind in the static tree.

Because parax.probability.is_leaf stops traversal at unwrappable nodes, this function sees each such node as a single leaf and routes it intact to the dynamic tree, where it can be safely unwrapped after partitioning. Consequently, if you want to pass the full, wrapped nodes through a jit boundary, you should add further filtering conditions or partitioning steps.

Parameters:

x (required): Any leaf node in the PyTree (as defined by is_leaf).

Returns:

bool: True if the node is meant for the inference engine. Matches:
1. Standard JAX inexact arrays (floating-point tensors).
2. Entire unwrappable probabilistic nodes.

Note: Explicitly returns False for parax.constant nodes, forcing them into the static tree.

Source code in parax/probability.py
def is_dynamic(x):
    """
    Identifies parameters that should be updated during probabilistic inference.

    This function acts as the primary filter for `eqx.partition`, determining 
    which nodes are routed to the `dynamic` (differentiable/optimizable) tree 
    and which are left behind in the `static` tree.

    Because `parax.probability.is_leaf` protects unwrappable nodes from being 
    split open, this function captures those nodes completely whole, allowing 
    them to be safely unwrapped *after* partitioning. Therefore, if you would
    like to pass the full, wrapped nodes through a jit boundary, you should
    include additional conditions or partitioning steps.

    Args:
        x: Any leaf node in the PyTree (as defined by `is_leaf`).

    Returns:
        bool: True if the node is meant for the inference engine. Matches:
            1. Standard JAX inexact arrays (floating-point tensors).
            2. Entire unwrappable probabilistic nodes.
        Note: Explicitly returns False for `parax.constant` nodes, forcing 
        them into the static tree.
    """    
    from parax.constants import is_constant
    if is_constant(x): 
        return False
    if _is_unwrappable_probabilistic(x): 
        return True
    return eqx.is_inexact_array(x)
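The partitioning pattern that is_dynamic is designed for can be sketched in plain JAX. The partition and combine functions below are minimal re-implementations of the idea behind eqx.partition and eqx.combine: split the tree into two trees with None placeholders, then merge them back. They are not the Equinox functions themselves. Constant and is_dynamic_sketch are hypothetical stand-ins for a parax constant wrapper and for is_dynamic. Constant is an unregistered class, so JAX treats it as a leaf even without an is_leaf argument.

```python
import jax
import jax.numpy as jnp


class Constant:
    """Hypothetical stand-in for a parax constant wrapper."""

    def __init__(self, value):
        self.value = value


def is_dynamic_sketch(x):
    # Constants are forced into the static tree, even if they hold floats.
    if isinstance(x, Constant):
        return False
    # Otherwise only floating-point (inexact) JAX arrays are dynamic.
    return isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.inexact)


def partition(tree, filter_fn, is_leaf=None):
    """Split `tree` into (dynamic, static), leaving None where a leaf went elsewhere."""
    dynamic = jax.tree_util.tree_map(
        lambda x: x if filter_fn(x) else None, tree, is_leaf=is_leaf
    )
    static = jax.tree_util.tree_map(
        lambda x: None if filter_fn(x) else x, tree, is_leaf=is_leaf
    )
    return dynamic, static


def combine(dynamic, static):
    """Merge two partitioned trees, taking the non-None entry at each position."""
    return jax.tree_util.tree_map(
        lambda a, b: b if a is None else a,
        dynamic,
        static,
        is_leaf=lambda x: x is None,
    )


tree = {"w": jnp.ones(2), "c": Constant(jnp.ones(())), "idx": jnp.arange(3)}
dynamic, static = partition(tree, is_dynamic_sketch)
print(dynamic)  # only "w" keeps its array; "c" and "idx" are None
restored = combine(dynamic, static)
```

The float array inside Constant stays in the static tree, which matches the behaviour the Note above describes for parax.constant nodes. The integer array idx is static because it is not inexact.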

parax.probability.is_leaf(x)

Defines the tree traversal boundaries for probabilistic partitioning.

In the Parax ecosystem, certain custom nodes (like unwrappable probabilistic priors or posteriors) contain internal metadata. If Equinox traverses inside these nodes, it will strip their differentiable arrays away from their metadata, causing structural mismatches during recombination.

This function tells JAX/Equinox to treat these specific Parax objects as opaque, indivisible leaves.

Parameters:

x (required): Any node encountered during PyTree traversal.

Returns:

bool: True if the node should NOT be traversed into. Matches:
1. Unwrappable probabilistic nodes (preserves their wrapper structure).
2. Constant nodes (protects static configuration objects).

Source code in parax/probability.py
def is_leaf(x):
    """
    Defines the tree traversal boundaries for probabilistic partitioning.

    In the Parax ecosystem, certain custom nodes (like unwrappable probabilistic 
    priors or posteriors) contain internal metadata. If Equinox traverses inside 
    these nodes, it will strip their differentiable arrays away from their metadata, 
    causing structural mismatches during recombination.

    This function tells JAX/Equinox to treat these specific Parax objects as 
    opaque, indivisible leaves. 

    Args:
        x: Any node encountered during PyTree traversal.

    Returns:
        bool: True if the node should NOT be traversed into. Matches:
            1. Unwrappable probabilistic nodes (preserves their wrapper structure).
            2. Constant nodes (protects static configuration objects).
    """
    from parax.constants import is_constant
    return _is_unwrappable_probabilistic(x) or is_constant(x)
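You can see the effect of an is_leaf predicate with jax.tree_util.tree_leaves, which takes the same kind of is_leaf argument as the Equinox functions. Prior and is_leaf_sketch are hypothetical. Prior is a registered PyTree node that holds an array plus a piece of static metadata, standing in for an unwrappable probabilistic node.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class Prior:
    """Hypothetical wrapper node: one array child plus static metadata."""

    def __init__(self, value, name):
        self.value = value
        self.name = name

    def tree_flatten(self):
        # `value` is a child; `name` is carried as aux data, not as a leaf.
        return (self.value,), self.name

    @classmethod
    def tree_unflatten(cls, name, children):
        return cls(children[0], name)


def is_leaf_sketch(x):
    # Treat Prior nodes as opaque leaves instead of recursing into them.
    return isinstance(x, Prior)


tree = {"w": Prior(jnp.ones(2), "weight-prior"), "b": jnp.zeros(3)}

# Default traversal descends into Prior and yields only its bare array.
default_leaves = jax.tree_util.tree_leaves(tree)
# With is_leaf, the Prior object comes back whole, wrapper and metadata included.
opaque_leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf_sketch)

print([type(leaf).__name__ for leaf in default_leaves])
print([type(leaf).__name__ for leaf in opaque_leaves])
```

Without the predicate, a filter sees only the bare array and never the wrapper around it. With the predicate, the filter receives the whole node and can decide its fate as one unit.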