Probability
parax.probability.AbstractProbabilistic
Bases: AbstractConstrained[T]
The abstract interface for a probabilistic PyTree.
Probabilistic PyTrees have an associated probability distribution
whose event shape matches the PyTree structure of self.
This class is used as the type check for parax.is_probabilistic.
Attributes:

| Name | Type | Description |
|---|---|---|
| distribution | AbstractVar[AbstractDistribution] | The probability distribution associated with this PyTree node. |
| constraint | AbstractVar[AbstractConstraint] | The active constraint of the PyTree. |
| bounds | AbstractVar[tuple[T, T]] | The current bounds of the PyTree. Each must have a matching PyTree structure as |
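The interface pattern can be sketched in plain Python. All names below (ProbabilisticSketch, NormalSketch, NormalParam, is_probabilistic_sketch) are hypothetical stand-ins, not parax classes. The real interface declares its fields as Equinox-style abstract fields (AbstractVar) and also requires constraint and bounds, which this sketch omits. It also assumes that the documented type check amounts to an isinstance test.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class ProbabilisticSketch(ABC):
    """Hypothetical stand-in for AbstractProbabilistic (constraint/bounds omitted)."""

    @property
    @abstractmethod
    def distribution(self):
        """The distribution associated with this node."""


@dataclass
class NormalSketch:
    """Hypothetical distribution placeholder: only stores its parameters."""
    loc: float
    scale: float


@dataclass
class NormalParam(ProbabilisticSketch):
    """A concrete probabilistic node: a scalar value with a Normal prior."""
    value: float
    prior: NormalSketch

    @property
    def distribution(self):
        return self.prior


def is_probabilistic_sketch(x) -> bool:
    # Assumption: the type check is an isinstance test against the abstract interface.
    return isinstance(x, ProbabilisticSketch)
```

Here is_probabilistic_sketch(NormalParam(0.3, NormalSketch(0.0, 1.0))) is True and is_probabilistic_sketch(0.3) is False. The abstract base itself cannot be instantiated, so every probabilistic node must supply a concrete distribution.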
parax.probability.tree_distributions(tree)
Extracts the individual probability distributions of a PyTree.
Standard arrays default to distreqx.distributions.ImproperUniform.
This function does not accept leaves that are neither arrays nor instances of
parax.probability.AbstractProbabilistic. If your tree contains such leaves, mark
them as static or filter them out first (e.g. with eqx.filter).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree containing probabilistic nodes or standard arrays. | required |
Returns:

| Type | Description |
|---|---|
| PyTree | A PyTree of the exact same structure containing the extracted probability distributions. |
Source code in parax/probability.py (lines 47-76)
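A minimal sketch of the extraction logic follows. It uses plain recursion over dicts, lists, and tuples in place of JAX's PyTree utilities so that it runs without JAX. Every class in it (ArraySketch, NormalSketch, ImproperUniformSketch, ProbLeaf) and the function tree_distributions_sketch are hypothetical stand-ins, not parax or distreqx types, and this is not the parax source.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ArraySketch:
    """Stand-in for an array leaf; only its shape matters here."""
    shape: tuple


@dataclass
class NormalSketch:
    """Hypothetical distribution placeholder."""
    loc: float
    scale: float


@dataclass
class ImproperUniformSketch:
    """Hypothetical placeholder for the flat default used for plain arrays."""
    shape: tuple


@dataclass
class ProbLeaf:
    """Hypothetical probabilistic leaf: a value plus its distribution."""
    value: Any
    distribution: Any


def tree_distributions_sketch(tree):
    # Containers are traversed recursively so the output keeps the input structure.
    if isinstance(tree, dict):
        return {k: tree_distributions_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_distributions_sketch(v) for v in tree)
    # Probabilistic leaves contribute their own distribution.
    if isinstance(tree, ProbLeaf):
        return tree.distribution
    # Plain array leaves fall back to a flat distribution of the same shape.
    if isinstance(tree, ArraySketch):
        return ImproperUniformSketch(tree.shape)
    # Any other leaf is rejected, mirroring the documented restriction.
    raise TypeError(f"unsupported leaf type: {type(tree).__name__}")
```

For example, tree_distributions_sketch({"w": ProbLeaf(0.5, NormalSketch(0.0, 1.0)), "b": [ArraySketch((3,))]}) returns {"w": NormalSketch(0.0, 1.0), "b": [ImproperUniformSketch((3,))]}. A string leaf such as {"name": "layer1"} raises TypeError, which is the situation that marking fields static or filtering is meant to avoid.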
parax.probability.tree_joint_distribution(tree)
Extracts the single joint probability distribution of a PyTree.
Wraps the output of parax.probability.tree_distributions
in a distreqx.distributions.Joint distribution to define
a single distribution that matches the shape of tree.
Note that this distribution is defined over the constrained space
of any parax.probability.AbstractProbabilistic variables.
For interoperability with unconstrained algorithms, see
parax.probability.tree_unconstrained_distribution.
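The distinction matters because a density does not carry over unchanged from one space to the other. The following is a standard change-of-variables illustration, not a description of how parax implements tree_unconstrained_distribution. If a constrained value x is obtained from an unconstrained value u through a bijection f, then

```latex
x = f(u), \qquad
p_U(u) = p_X\big(f(u)\big)\,\left|\det \frac{\partial f(u)}{\partial u}\right|
```

For a positivity constraint with f(u) = exp(u), this gives log p_U(u) = log p_X(exp(u)) + u. An algorithm working in the unconstrained space therefore needs the extra term u to target the same distribution.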
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |
Returns:

| Type | Description |
|---|---|
| AbstractDistribution | A single joint distribution whose event shape matches the structure of tree. |
Source code in parax/probability.py (lines 79-99)
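The sketch below assumes that the components are independent, as in distrax's Joint distribution. Under that assumption, the joint log-density is the sum of the per-leaf log-densities. The stand-in classes (NormalSketch, ImproperUniformSketch, JointSketch) and the helper _leaves are hypothetical, so this is not the distreqx implementation. The zero log-density given to ImproperUniformSketch is also an assumption about how a flat, unnormalised prior behaves.

```python
import math
from dataclasses import dataclass


@dataclass
class NormalSketch:
    """Hypothetical Normal distribution with a log-density."""
    loc: float
    scale: float

    def log_prob(self, x: float) -> float:
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


@dataclass
class ImproperUniformSketch:
    """Hypothetical flat prior: assumed to contribute zero log-density."""

    def log_prob(self, x: float) -> float:
        return 0.0


def _leaves(tree):
    # Flatten dicts (in sorted key order), lists, and tuples into a leaf sequence.
    if isinstance(tree, dict):
        for key in sorted(tree):
            yield from _leaves(tree[key])
    elif isinstance(tree, (list, tuple)):
        for item in tree:
            yield from _leaves(item)
    else:
        yield tree


@dataclass
class JointSketch:
    """Hypothetical joint over a tree of independent component distributions."""
    distributions: object

    def log_prob(self, values) -> float:
        dists, xs = list(_leaves(self.distributions)), list(_leaves(values))
        if len(dists) != len(xs):
            raise ValueError("values must have as many leaves as the distribution tree")
        # Independence means the joint log-density is the sum over matching leaves.
        return sum(d.log_prob(x) for d, x in zip(dists, xs))
```

For example, JointSketch({"w": NormalSketch(0.0, 1.0), "b": ImproperUniformSketch()}).log_prob({"w": 0.0, "b": 5.0}) equals the standard-normal log-density at zero, -0.5 * log(2π) ≈ -0.919, because the flat component adds nothing.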