Probabilistic
- class pmrf.models.composite.wrapped.Probabilistic(model: pmrf.models.base.Model, distribution: distreqx.distributions._distribution.AbstractDistribution, target: typing.Callable[[typing.Any], typing.Any] = <function Probabilistic.<lambda>>, constraint: parax.constraints.AbstractConstraint | None = None)
Bases: Model
(experimental) A wrapper to make an existing model probabilistic.
This makes it possible to associate a probability distribution with a model, or with one of its sub-models or parameters, after the model has been created.
This is useful for advanced use cases where you want to attach a distribution to an entire model (possibly overriding distributions defined at lower levels), as opposed to the more common case of modelling the distributions of individual variables (for which you should likely use pmrf.Random instead).
- Variables:
probabilistic (Model | parax.Probabilize) – The updated structure containing the parax.Probabilize node.
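To make the `Model | parax.Probabilize` type of probabilistic more concrete, here is a minimal stdlib-only sketch, assuming that wrapping replaces the targeted subtree with a node pairing the original value with its distribution. `ProbabilizeNode`, `ToyNormal`, `ToyResistor`, `ToySeries` and `attach` are hypothetical illustrations, not pmrf or parax API, and the target is given as a tuple of field names rather than a lens to keep the sketch short.

```python
from dataclasses import dataclass, replace
from typing import Any, Tuple

# Hypothetical stand-ins for illustration only; these are NOT pmrf/parax classes.
@dataclass(frozen=True)
class ToyNormal:
    loc: float
    scale: float

@dataclass(frozen=True)
class ProbabilizeNode:
    """Pairs an original subtree with the distribution attached to it (assumed layout)."""
    value: Any
    distribution: Any

@dataclass(frozen=True)
class ToyResistor:
    R: float

@dataclass(frozen=True)
class ToySeries:
    first: ToyResistor
    second: ToyResistor

def attach(tree: Any, path: Tuple[str, ...], distribution: Any) -> Any:
    """Return a copy of `tree` whose subtree at `path` is wrapped in a ProbabilizeNode.

    An empty path targets the root, so the result is the node itself;
    any other path returns a model of the same type that contains the node.
    """
    if not path:
        return ProbabilizeNode(tree, distribution)
    head, rest = path[0], path[1:]
    new_child = attach(getattr(tree, head), rest, distribution)
    # Frozen dataclasses are rebuilt rather than mutated, similar to JAX PyTrees.
    return replace(tree, **{head: new_child})

circuit = ToySeries(ToyResistor(50.0), ToyResistor(75.0))

# Target a single leaf: the result is still a ToySeries containing the node.
leaf_wrapped = attach(circuit, ("first", "R"), ToyNormal(50.0, 1.0))

# Target the whole model with a matching distribution tree: the result is the node.
dist_tree = ToySeries(ToyResistor(ToyNormal(50.0, 1.0)), ToyResistor(ToyNormal(75.0, 2.0)))
root_wrapped = attach(circuit, (), dist_tree)

print(type(leaf_wrapped).__name__)  # ToySeries
print(type(root_wrapped).__name__)  # ProbabilizeNode
```

The original `circuit` is left unchanged; each call builds an updated structure instead of mutating the model.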
Examples
>>> from pmrf.models import Probabilistic, Resistor
>>> from pmrf.distributions import Normal, Joint
>>>
>>> res = Resistor(R=50.0)
>>>
>>> # Use Case 1: Target a specific parameter (leaf)
>>> prob_res_leaf = Probabilistic(
...     model=res,
...     distribution=Normal(loc=50.0, scale=1.0),
...     target=lambda m: m.R
... )
>>>
>>> # Use Case 2: Wrap the entire model (requires matching distribution tree)
>>> import equinox as eqx
>>> dist_tree = Joint(eqx.tree_at(lambda m: m.R, res, Normal(loc=50.0, scale=1.0)))
>>> prob_res_tree = Probabilistic(
...     model=res,
...     distribution=dist_tree,
... )
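Use Case 2 builds dist_tree by copying the model with eqx.tree_at and swapping the R leaf for a Normal, so the distribution tree has the model's shape with distributions at the leaves. The stdlib-only sketch below illustrates that "same structure" requirement with a simple shape comparison, assuming distributions are treated as leaves. `structure`, `is_dist`, `ToyNormal` and `ToyResistor` are hypothetical and do not reflect how pmrf checks trees; in JAX one would more likely compare `jax.tree_util.tree_structure` results using an `is_leaf` predicate.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Callable

def structure(tree: Any, is_leaf: Callable[[Any], bool] = lambda x: False) -> Any:
    """Return a comparable description of a tree's shape, ignoring leaf values."""
    if is_leaf(tree):
        return "*"
    if is_dataclass(tree) and not isinstance(tree, type):
        return (type(tree).__name__,
                tuple((f.name, structure(getattr(tree, f.name), is_leaf)) for f in fields(tree)))
    if isinstance(tree, (list, tuple)):
        return (type(tree).__name__, tuple(structure(x, is_leaf) for x in tree))
    if isinstance(tree, dict):
        return ("dict", tuple((k, structure(v, is_leaf)) for k, v in sorted(tree.items())))
    return "*"  # anything else (floats, arrays, ...) counts as a leaf

# Hypothetical stand-ins; NOT pmrf classes.
@dataclass(frozen=True)
class ToyNormal:
    loc: float
    scale: float

@dataclass(frozen=True)
class ToyResistor:
    R: float

def is_dist(x: Any) -> bool:
    # Treat distributions as leaves so their own fields (loc, scale) are not compared.
    return isinstance(x, ToyNormal)

res = ToyResistor(R=50.0)
dist_tree = ToyResistor(R=ToyNormal(50.0, 1.0))  # analogue of the eqx.tree_at call above

whole_ok = structure(res, is_dist) == structure(dist_tree, is_dist)              # whole model
leaf_ok = structure(res.R, is_dist) == structure(ToyNormal(50.0, 1.0), is_dist)  # single leaf
mismatch = structure(res, is_dist) == structure(ToyNormal(50.0, 1.0), is_dist)   # wrong shape
print(whole_ok, leaf_ok, mismatch)  # True True False
```

The last comparison fails because a single distribution cannot stand in for a whole model; it only matches when the target is the R leaf itself, as in Use Case 1.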
Initialize the Probabilistic model.
- Parameters:
model (Model) – The base model to wrap.
distribution (AbstractDistribution) – The probability distribution to associate with the target. Must have the same JAX PyTree structure as the targeted part of model (the entire model when target is the identity function).
target (callable, optional) – A callable (lens) extracting the parameter to make probabilistic (e.g., lambda m: m.R). Defaults to the identity function, meaning the distribution applies to the entire model.
constraint (AbstractConstraint, optional) – An optional constraint for the distribution. Must have the same JAX PyTree structure as the targeted part of model (the entire model when target is the identity function). If None, it is inferred from the distribution.
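The target parameter is described as a lens, with the identity function as the default. As a rough illustration of how such a lens can identify where a distribution should be attached, the stdlib-only sketch below passes a recording proxy through the lambda and reads back the attribute path. `PathRecorder` and `lens_path` are hypothetical names, not pmrf API; Equinox's `eqx.tree_at`, used in the example above, is more general than this sketch, which only handles plain attribute chains.

```python
from typing import Any, Callable, Tuple

class PathRecorder:
    """Hypothetical proxy that records the chain of attribute accesses made on it."""

    def __init__(self, path: Tuple[str, ...] = ()) -> None:
        self.recorded_path = path

    def __getattr__(self, name: str) -> "PathRecorder":
        # Only invoked for names not found normally, i.e. the lens's field accesses.
        return PathRecorder(self.recorded_path + (name,))

def lens_path(target: Callable[[Any], Any]) -> Tuple[str, ...]:
    """Resolve a lens such as `lambda m: m.R` into a tuple of field names."""
    result = target(PathRecorder())
    if not isinstance(result, PathRecorder):
        raise TypeError("this sketch only supports plain attribute-access lenses")
    return result.recorded_path

print(lens_path(lambda m: m.R))        # ('R',)
print(lens_path(lambda m: m.first.R))  # ('first', 'R')
print(lens_path(lambda m: m))          # ()
```

With the default identity lens the path is empty, which corresponds to attaching the distribution to the entire model.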