Probabilistic
parax.probabilistic.AbstractProbabilistic
Bases: Module, Generic[Base]
The abstract interface for a probabilistic PyTree.
A probabilistic PyTree has a probability distribution associated with it: samples drawn from that distribution should have the same PyTree structure as self.
The interface is built around a "base" space, the representation in which inference algorithms operate.
parax.is_probabilistic uses this class for its type check.
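The following is a minimal pure-Python sketch of the interface described above. It uses the standard library's abc module in place of the Module base (presumably equinox.Module, given the AbstractVar annotation). The names ProbabilisticSketch and is_probabilistic_sketch are hypothetical. The helper assumes that parax.is_probabilistic is a plain isinstance check, which the description suggests but does not spell out.

```python
from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar

# Type of the base-space representation (mirrors the Generic[Base] parameter).
Base = TypeVar("Base")


class ProbabilisticSketch(ABC, Generic[Base]):
    """Hypothetical stand-in for AbstractProbabilistic (not the parax class)."""

    # In parax this is declared as AbstractVar[AbstractDistribution];
    # here it is only an annotation that subclasses are expected to fill in.
    distribution: Any

    @property
    @abstractmethod
    def base(self) -> Base:
        """Current state of this node expressed in the base space."""

    @abstractmethod
    def update(self, base: Base) -> "ProbabilisticSketch[Base]":
        """Return a new node whose state is rebuilt from `base`."""


def is_probabilistic_sketch(x: Any) -> bool:
    # Assumed behaviour: a probabilistic node is any instance of the interface.
    return isinstance(x, ProbabilisticSketch)
```

Because base and update are abstract, the interface itself cannot be instantiated. Only concrete subclasses that implement both can be instantiated.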
Attributes:

| Name | Type | Description |
|---|---|---|
| distribution | AbstractVar[AbstractDistribution] | The probability distribution associated with this PyTree node. |
base (abstractmethod, property)
Returns the current state of this PyTree expressed in the base space.
update(base) (abstractmethod)
Returns a new instance of this object updated with a new base PyTree.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| base | Base | The new base-space PyTree representing the sampled state. | required |
Returns:

| Type | Description |
|---|---|
| AbstractProbabilistic | A new instance of the probabilistic object, updated to reflect the new base. |
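The following example makes the base/update contract concrete. It defines a hypothetical node, not part of parax, that holds a positive value and uses the logarithm of that value as its base space. Log space is a common way to give an inference algorithm an unconstrained coordinate. That parameterization is an illustrative assumption, and the source does not say which base spaces parax uses. The distribution attribute is omitted to keep the focus on base and update.

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, replace


class ProbabilisticSketch(ABC):
    """Hypothetical stand-in for the abstract interface."""

    @property
    @abstractmethod
    def base(self): ...

    @abstractmethod
    def update(self, base): ...


@dataclass(frozen=True)
class PositiveScale(ProbabilisticSketch):
    """Hypothetical node: a positive value whose base space is log-space."""

    value: float

    @property
    def base(self) -> float:
        # Unconstrained representation that an inference algorithm would move.
        return math.log(self.value)

    def update(self, base: float) -> "PositiveScale":
        # Return a NEW instance (immutable, in the spirit of a PyTree module),
        # mapped back from the base space; self is left unchanged.
        return replace(self, value=math.exp(base))


scale = PositiveScale(2.0)
moved = scale.update(scale.base + 1.0)  # one step taken in base space
# moved.value is approximately 2 * e, and scale.value is still 2.0
```

Returning a new object rather than mutating in place matches the "returns a new instance" contract above. This also keeps nodes compatible with functional transformations.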
Source code in parax/probabilistic.py
parax.probabilistic.tree_distribution(tree)
Extracts the probability distributions of a PyTree.
Leaves that are standard arrays rather than probabilistic nodes default to a distreqx.ImproperUniform distribution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |
Returns:

| Type | Description |
|---|---|
| PyTree | A PyTree with the same structure as tree, containing the extracted probability distributions. |
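The extraction can be sketched in pure Python as a recursive map over nested dicts, lists, and tuples. ProbNode, NormalStandIn, and ImproperUniformStandIn are hypothetical stand-ins for parax probabilistic nodes and distreqx distributions. A JAX implementation would more likely use jax.tree_util.tree_map with an is_leaf predicate, so that probabilistic nodes are treated as leaves instead of being traversed into.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ImproperUniformStandIn:
    """Placeholder for distreqx's ImproperUniform (not the real class)."""


@dataclass(frozen=True)
class NormalStandIn:
    """Placeholder for a proper distribution attached to a node."""
    loc: float
    scale: float


@dataclass(frozen=True)
class ProbNode:
    """Hypothetical probabilistic leaf carrying its own distribution."""
    value: Any
    distribution: Any


def tree_distribution_sketch(tree: Any) -> Any:
    """Return a tree of the same structure holding one distribution per leaf."""
    if isinstance(tree, ProbNode):
        # Probabilistic nodes are leaves that report their own distribution.
        return tree.distribution
    if isinstance(tree, dict):
        return {k: tree_distribution_sketch(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_distribution_sketch(v) for v in tree)
    # Any other leaf (standing in for a plain array) gets an improper uniform.
    return ImproperUniformStandIn()


model = {"w": 1.5, "b": ProbNode(0.0, NormalStandIn(0.0, 1.0)), "layers": [2.0, 3.0]}
dists = tree_distribution_sketch(model)
```

The output mirrors the input container by container. Only the leaves change, from values to distributions.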
Source code in parax/probabilistic.py
parax.probabilistic.tree_joint(tree)
Extracts a single joint probability distribution for a PyTree.
Wraps the output of parax.probabilistic.tree_distribution in a distreqx.distributions.Joint distribution, producing one distribution whose samples match the structure of tree.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tree | PyTree | The PyTree model containing probabilistic nodes or standard arrays. | required |
Returns:

| Type | Description |
|---|---|
| Joint | A single joint distribution whose samples match the structure of tree. |
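A joint over a tree of distributions can be sketched as an object with two behaviours. Its sample method returns a value with the same structure as the tree, and its log-density is the sum of the per-leaf log-densities. Summing is only valid if the leaves are independent, so this sketch assumes independence. JointSketch, NormalStandIn, tree_map, and tree_leaves are hypothetical pure-Python stand-ins. They are not the distreqx Joint API.

```python
import math
import random
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class NormalStandIn:
    """Hypothetical scalar normal distribution."""
    loc: float
    scale: float

    def sample(self, rng: random.Random) -> float:
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x: float) -> float:
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


def tree_map(fn, tree, *rest):
    """Apply fn leaf-wise over matching dict/list/tuple trees."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, tree[k], *(r[k] for r in rest)) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, t, *rs) for t, *rs in zip(tree, *rest))
    return fn(tree, *rest)


def tree_leaves(tree):
    """Flatten a dict/list/tuple tree into a list of leaves."""
    if isinstance(tree, dict):
        return [leaf for k in tree for leaf in tree_leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for t in tree for leaf in tree_leaves(t)]
    return [tree]


@dataclass(frozen=True)
class JointSketch:
    """Hypothetical joint over a tree of independent distributions."""
    dists: Any

    def sample(self, rng: random.Random) -> Any:
        # One draw per leaf, reassembled into the tree's structure.
        return tree_map(lambda d: d.sample(rng), self.dists)

    def log_prob(self, value: Any) -> float:
        # Independence assumption: joint log-density = sum of leaf log-densities.
        per_leaf = tree_map(lambda d, x: d.log_prob(x), self.dists, value)
        return sum(tree_leaves(per_leaf))


dists = {"a": NormalStandIn(0.0, 1.0), "b": [NormalStandIn(1.0, 2.0), NormalStandIn(-1.0, 0.5)]}
joint = JointSketch(dists)
draw = joint.sample(random.Random(0))
```

Wrapping the whole tree in one object lets inference code treat the model as a single distribution. The PyTree structure of every sample is preserved.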
Source code in parax/probabilistic.py