Probability
parax.probability.AbstractProbabilistic
Bases: AbstractConstrained[T]
The abstract interface for a probabilistic PyTree.
Probabilistic PyTrees have a probability distribution associated with them: the event shape of the distribution matches the PyTree structure of self.
This class serves as the type check for parax.is_probabilistic.
Attributes:
| Name | Type | Description |
|---|---|---|
| distribution | AbstractVar[AbstractDistribution] | The probability distribution associated with this PyTree node. |
| constraint | AbstractVar[AbstractConstraint] | The active constraint of the PyTree. |
| bounds | AbstractVar[tuple[T, T]] | The current PyTree bounds. Each bound must have the same PyTree structure as self. |
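As a rough illustration of the interface, the following pure-Python sketch defines a hypothetical abstract base with the same three attributes and uses `isinstance` as the type check, in the spirit of parax.is_probabilistic. The names `ProbabilisticSketch`, `ToyUniform`, `BoundedScalar` and `is_probabilistic_sketch` are invented for this example; the real class uses Equinox `AbstractVar` fields and distreqx distributions rather than plain properties and strings.

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyUniform:
    """Hypothetical stand-in for a distreqx distribution on [low, high]."""
    low: float
    high: float

    def log_prob(self, x: float) -> float:
        if self.low <= x <= self.high:
            return -math.log(self.high - self.low)
        return float("-inf")


class ProbabilisticSketch(ABC):
    """Hypothetical analogue of AbstractProbabilistic: subclasses must
    provide a distribution, a constraint and a (lower, upper) bounds pair."""

    @property
    @abstractmethod
    def distribution(self) -> ToyUniform: ...

    @property
    @abstractmethod
    def constraint(self) -> str: ...

    @property
    @abstractmethod
    def bounds(self) -> tuple: ...


def is_probabilistic_sketch(x) -> bool:
    # The abstract base class doubles as the type check.
    return isinstance(x, ProbabilisticSketch)


@dataclass(frozen=True)
class BoundedScalar(ProbabilisticSketch):
    """A scalar parameter constrained to [low, high] with a uniform prior."""
    value: float
    low: float = 0.0
    high: float = 1.0

    @property
    def distribution(self) -> ToyUniform:
        return ToyUniform(self.low, self.high)

    @property
    def constraint(self) -> str:
        return "interval"  # placeholder label, not a real constraint object

    @property
    def bounds(self) -> tuple:
        # Each bound has the same structure (a scalar) as the value itself.
        return (self.low, self.high)
```

Keeping the three attributes abstract means a subclass cannot be instantiated until it supplies all of them, so the type check can rely on every probabilistic node exposing a distribution, a constraint and bounds.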
parax.probability.tree_distributions(tree)
Extracts the individual probability distributions of a PyTree.
Leaves that are standard arrays default to distreqx.distributions.ImproperUniform.
Every leaf must be either an array or a probabilistic node. If your tree contains leaves that are neither arrays nor subclasses of parax.probability.AbstractProbabilistic, mark them as static or filter them out first, e.g. with eqx.filter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree containing probabilistic nodes or standard arrays. | required |
Returns:
| Type | Description |
|---|---|
| PyTree | A PyTree with exactly the same structure as tree, containing the extracted probability distributions. |
Source code in parax/probability.py (lines 45-74)
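The extraction rule can be sketched in plain Python, assuming nested dicts, lists and tuples as the containers and Python floats as stand-ins for arrays. `ProbabilisticLeaf`, `NormalSketch`, `ImproperUniformSketch` and `tree_distributions_sketch` are hypothetical names; the sketch only mirrors the documented behavior (probabilistic leaves yield their distribution, plain arrays fall back to an improper uniform, anything else is rejected) and is not parax's implementation.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class ImproperUniformSketch:
    """Hypothetical stand-in for ImproperUniform: a flat log-density."""

    def log_prob(self, x: float) -> float:
        return 0.0


@dataclass(frozen=True)
class NormalSketch:
    """Hypothetical stand-in for a proper scalar distribution."""
    loc: float
    scale: float

    def log_prob(self, x: float) -> float:
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


@dataclass(frozen=True)
class ProbabilisticLeaf:
    """Hypothetical probabilistic node: a value plus its distribution."""
    value: float
    distribution: NormalSketch


def tree_distributions_sketch(tree):
    """Map every leaf of a nested dict/list/tuple tree to a distribution."""
    if isinstance(tree, ProbabilisticLeaf):
        return tree.distribution
    if isinstance(tree, float):
        # Plain "arrays" (floats here) fall back to an improper uniform.
        return ImproperUniformSketch()
    if isinstance(tree, dict):
        return {k: tree_distributions_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Rebuild the same container type so the structure is preserved.
        return type(tree)(tree_distributions_sketch(v) for v in tree)
    # Anything else is rejected, mirroring the documented restriction.
    raise TypeError(f"unsupported leaf of type {type(tree).__name__}")
```

One reason to raise on unsupported leaves, rather than silently skipping them, is that the output must keep exactly the input's structure so it can be paired leaf-for-leaf with the original tree.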
parax.probability.tree_joint_distribution(tree)
Extracts a single joint probability distribution from a PyTree.
Wraps the output of parax.probability.tree_distributions in a distreqx.distributions.Joint distribution, defining a single distribution that matches the shape of tree.
Note that this distribution is defined over the constrained space of any parax.probability.AbstractProbabilistic variables. For interoperability with algorithms that operate in unconstrained space, see parax.probability.tree_unconstrained_distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |
Returns:
| Type | Description |
|---|---|
| AbstractDistribution | A single joint distribution whose event shape matches the structure of tree. |
Source code in parax/probability.py (lines 77-97)
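A joint distribution over a tree can be sketched as a wrapper that pairs each component distribution with the matching leaf of a value tree. This plain-Python version assumes the components are statistically independent, so the joint log-density is the sum of the component log-densities; `JointSketch`, `NormalSketch`, `ImproperUniformSketch` and `_leaves` are hypothetical stand-ins, not the distreqx.distributions.Joint API.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class NormalSketch:
    loc: float
    scale: float

    def log_prob(self, x: float) -> float:
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


@dataclass(frozen=True)
class ImproperUniformSketch:
    def log_prob(self, x: float) -> float:
        return 0.0  # flat, unnormalized density


def _leaves(tree):
    """Flatten a nested dict/list/tuple into (path, leaf) pairs in a fixed order."""
    if isinstance(tree, dict):
        for k in sorted(tree):
            for path, leaf in _leaves(tree[k]):
                yield (k,) + path, leaf
    elif isinstance(tree, (list, tuple)):
        for i, v in enumerate(tree):
            for path, leaf in _leaves(v):
                yield (i,) + path, leaf
    else:
        yield (), tree


@dataclass(frozen=True)
class JointSketch:
    """Hypothetical joint over a tree of independent component distributions."""
    distributions: object  # a tree whose leaves have .log_prob

    def log_prob(self, values) -> float:
        dists = list(_leaves(self.distributions))
        vals = list(_leaves(values))
        # The value tree must have the same structure as the distribution tree.
        if [p for p, _ in dists] != [p for p, _ in vals]:
            raise ValueError("value tree does not match the distribution tree")
        # Independent components: the joint log-density is their sum.
        return sum(d.log_prob(v) for (_, d), (_, v) in zip(dists, vals))
```

In this sketch an improper-uniform component contributes a constant (zero) log-density, so only the proper components change the joint value.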
parax.probability.is_dynamic(x)
Identifies the parameters that should be updated during probabilistic inference.
This function is the primary filter for eqx.partition: it determines which nodes are routed to the dynamic (differentiable/optimizable) tree and which are left in the static tree.
Because parax.probability.is_leaf protects unwrappable nodes from being split open, this function captures those nodes whole, so they can be safely unwrapped after partitioning. The dynamic tree therefore contains entire wrapped nodes rather than just their arrays; if you want to pass the full, wrapped nodes through a jit boundary, add further filter conditions or partitioning steps.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | | Any leaf node in the PyTree (as defined by …). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| | bool | True if the node is meant for the inference engine. Matches (1) standard JAX inexact arrays (floating-point tensors) and (2) entire unwrappable probabilistic nodes. |
Note: Explicitly returns False for … them into the static tree.
Source code in parax/probability.py (lines 169-198)
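The partitioning described above can be mimicked with a small pure-Python `partition`/`combine` pair that, like eqx.partition, returns two trees of the same structure with `None` marking the holes. `WrappedPrior`, `Constant`, `is_leaf_sketch` and `is_dynamic_sketch` are hypothetical stand-ins for Parax's node types and filters, with Python floats playing the role of inexact arrays.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class WrappedPrior:
    """Hypothetical unwrappable node: an array-like value plus metadata."""
    value: float
    name: str  # non-array metadata that must stay attached to `value`


@dataclass(frozen=True)
class Constant:
    """Hypothetical constant node that should never be optimized."""
    value: float


def is_leaf_sketch(x) -> bool:
    # Treat wrapped and constant nodes as opaque leaves.
    return isinstance(x, (WrappedPrior, Constant))


def is_dynamic_sketch(x) -> bool:
    if isinstance(x, Constant):
        return False  # constants always go to the static tree
    # Floats stand in for inexact arrays; wrapped nodes are captured whole.
    return isinstance(x, (float, WrappedPrior))


def partition(tree, filter_fn, is_leaf):
    """Split a dict/list tree into (dynamic, static), using None as a hole."""
    if not is_leaf(tree):
        if isinstance(tree, dict):
            pairs = {k: partition(v, filter_fn, is_leaf) for k, v in tree.items()}
            return ({k: p[0] for k, p in pairs.items()},
                    {k: p[1] for k, p in pairs.items()})
        if isinstance(tree, list):
            pairs = [partition(v, filter_fn, is_leaf) for v in tree]
            return [p[0] for p in pairs], [p[1] for p in pairs]
    return (tree, None) if filter_fn(tree) else (None, tree)


def combine(a, b):
    """Inverse of partition: fill each None hole in `a` from `b`."""
    if isinstance(a, dict):
        return {k: combine(a[k], b[k]) for k in a}
    if isinstance(a, list):
        return [combine(x, y) for x, y in zip(a, b)]
    return b if a is None else a
```

With the real libraries, the corresponding split would presumably be `eqx.partition(model, parax.probability.is_dynamic, is_leaf=parax.probability.is_leaf)`. As in the sketch, the wrapped node lands in the dynamic tree whole, metadata included.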
parax.probability.is_leaf(x)
Defines the tree traversal boundaries for probabilistic partitioning.
In the Parax ecosystem, certain custom nodes (such as unwrappable probabilistic priors or posteriors) contain internal metadata. If Equinox traverses into these nodes, it strips their differentiable arrays away from their metadata, causing structural mismatches when the trees are recombined.
This function tells JAX/Equinox to treat these Parax objects as opaque, indivisible leaves.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | | Any node encountered during PyTree traversal. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| | bool | True if the node should NOT be traversed into. Matches (1) unwrappable probabilistic nodes (preserving their wrapper structure) and (2) constant nodes (protecting static configuration objects). |
Source code in parax/probability.py (lines 146-167)
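The effect of the traversal boundary can be seen with a toy flattener that, like Equinox modules, treats dataclass instances as containers unless an `is_leaf` predicate marks them as opaque. `WrappedPrior`, `flatten` and `is_leaf_sketch` are hypothetical names used only for this pure-Python illustration.

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass(frozen=True)
class WrappedPrior:
    """Hypothetical unwrappable node: a value bundled with metadata."""
    value: float
    name: str


def flatten(tree, is_leaf=None):
    """Return the leaves of a tree of dicts, lists and dataclasses.

    Dataclass instances are traversed field by field, the way Equinox
    modules are PyTrees, unless `is_leaf` marks them as opaque.
    """
    if is_leaf is not None and is_leaf(tree):
        return [tree]
    if isinstance(tree, dict):
        return [leaf for k in sorted(tree) for leaf in flatten(tree[k], is_leaf)]
    if isinstance(tree, list):
        return [leaf for v in tree for leaf in flatten(v, is_leaf)]
    if is_dataclass(tree) and not isinstance(tree, type):
        return [leaf for f in fields(tree)
                for leaf in flatten(getattr(tree, f.name), is_leaf)]
    return [tree]


def is_leaf_sketch(x) -> bool:
    # Keep wrapped nodes whole so their value and metadata stay together.
    return isinstance(x, WrappedPrior)
```

Without the predicate, the node's value and its metadata become separate leaves; with it, the node survives as a single leaf, which is the behavior parax.probability.is_leaf is described as providing.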