DataLikelihood
- class DataLikelihood(predictor, data, likelihood, discrepancy=None, *, name=None)
Bases: Evaluator
Computes the log probability of observing the given data under a likelihood function conditioned on a model prediction.
Optionally accepts a probabilistic discrepancy model (e.g. a Gaussian process). The discrepancy model takes the model prediction and a scaled frequency vector and returns a distribution over the model’s prediction.
- Parameters:
- data: Array
- flat_param_names(*args, **kwargs)
Return flattened parameter names as a list.
See parax.Module.named_flat_params.
- flat_param_values(*args, **kwargs)
Return flattened module parameter values as a jax array.
See parax.Module.named_flat_param_values.
- Return type:
Array
- flat_params(*args, **kwargs)
Return flattened parameters as a list.
See parax.Module.named_flat_params.
- Return type:
list[Parameter]
- func_jacobian(func, args)
Calculate the Jacobian of an arbitrary function with respect to free parameters.
This uses forward-mode automatic differentiation to compute the gradients of the provided function with respect to each free parameter in the module.
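Forward mode needs one pass per free parameter, and each pass yields a derivative with the same shape as the function's output, which is why the result is keyed by parameter name. The following is a minimal stdlib-only Python sketch of that idea using dual numbers; it is not parax's implementation (which relies on JAX), and `Dual` and `forward_jacobian` are hypothetical names. Outputs are lists rather than arrays for simplicity.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Dual number value + tangent*eps with eps**2 == 0 (hypothetical helper)."""
    value: float
    tangent: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

    __rmul__ = __mul__


def forward_jacobian(func, params):
    """One forward pass per parameter; each entry matches func's output length."""
    jac = {}
    for name in params:
        # Seed the tangent of `name` with 1 and every other parameter with 0.
        seeded = {k: Dual(v, 1.0 if k == name else 0.0) for k, v in params.items()}
        outputs = func(seeded)
        jac[name] = [o.tangent if isinstance(o, Dual) else 0.0 for o in outputs]
    return jac


# Usage: a line y_i = a * x_i + b evaluated at three points.
xs = [0.0, 1.0, 2.0]
line = lambda p: [p["a"] * x + p["b"] for x in xs]
jacobian = forward_jacobian(line, {"a": 2.0, "b": 0.5})
```

For the line model, the derivative with respect to a is the x grid itself and the derivative with respect to b is all ones.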
- Parameters:
func (Callable[[Module], jnp.ndarray]) – Function to differentiate. Must take a Module and args and return a jnp.ndarray of any shape.
args (Any) – The args to pass to func.
self (Self)
- Returns:
A dictionary mapping flat parameter names to their gradient arrays. Each array has the same shape as the output of func.
- Return type:
- func_samples(func, args, *, key, num_samples=1000)
Evaluates an arbitrary function over samples drawn from the module’s distribution.
- Parameters:
- Returns:
The function evaluated over all samples. Shape will be (num_samples, *func_output_shape).
- Return type:
jnp.ndarray
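A rough sketch of the sampling loop, assuming a seeded `random.Random` in place of the JAX PRNG key and a hypothetical `sample_params` callable in place of the module's distribution. Each draw produces one row, so the result has shape (num_samples, *func_output_shape).

```python
import random


def func_samples_sketch(func, sample_params, *, seed, num_samples=1000):
    """Hypothetical sketch: evaluate `func` on independent parameter draws.

    `sample_params(rng)` returns one dict of parameter values and `func`
    returns a list, so the result is num_samples rows of len(func_output).
    """
    rng = random.Random(seed)  # seeded generator stands in for a JAX PRNG key
    return [func(sample_params(rng)) for _ in range(num_samples)]


# Usage: a ~ Uniform(1.5, 2.5); the function returns two outputs per draw.
draw = lambda rng: {"a": rng.uniform(1.5, 2.5)}
outputs = func_samples_sketch(lambda p: [p["a"], 2 * p["a"]], draw, seed=0)
```

Reusing the same seed reproduces the same samples, mirroring how reusing a JAX key is deterministic.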
- func_sensitivity(func, args, kind='relative', norm=None)
Calculate the sensitivity of an arbitrary function with respect to the parameters.
Supported kinds:
- ‘relative’: (dy/dtheta) * (theta/y). Fractional change in output per fractional change in parameter. Diverges where y is zero.
- ‘semi-relative’: (dy/dtheta) * theta. Change in output per fractional change in parameter. Remains finite where y is zero.
- ‘absolute’: dy/dtheta. The raw gradient.
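The three kinds can be checked on a small power law, y = a * b**2, where the relative sensitivities equal the exponents (1 for a, 2 for b). This sketch is plain Python with hypothetical names; it estimates dy/dtheta with central finite differences rather than the automatic differentiation parax uses for func_jacobian, and it handles a scalar output only.

```python
import math


def sensitivities(func, params, kind="relative", norm=None, eps=1e-6):
    """Hypothetical sketch: sensitivity of a scalar func(params) to each parameter.

    If `norm` is given, per-parameter values are reduced to one number with
    that p-norm (math.inf gives the max norm).
    """
    y = func(params)
    out = {}
    for name, theta in params.items():
        # Central finite difference stands in for automatic differentiation.
        up = {**params, name: theta + eps}
        down = {**params, name: theta - eps}
        dy_dtheta = (func(up) - func(down)) / (2 * eps)
        if kind == "relative":
            out[name] = dy_dtheta * theta / y  # ZeroDivisionError when y == 0
        elif kind == "semi-relative":
            out[name] = dy_dtheta * theta  # finite even when y == 0
        elif kind == "absolute":
            out[name] = dy_dtheta
        else:
            raise ValueError(f"unknown kind: {kind!r}")
    if norm is None:
        return out
    magnitudes = [abs(v) for v in out.values()]
    if norm == math.inf:
        return max(magnitudes)
    return sum(m ** norm for m in magnitudes) ** (1.0 / norm)


power_law = lambda p: p["a"] * p["b"] ** 2
rel = sensitivities(power_law, {"a": 2.0, "b": 3.0})
```

A relative sensitivity of 2 for b reads as "a 1% change in b gives roughly a 2% change in y", which is why that kind suits parameters spanning different scales.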
- Parameters:
func (Callable[[Module], jnp.ndarray]) – Function to evaluate.
args (Any) – The args to pass to func.
kind (str, default='relative') – The type of sensitivity to calculate (‘relative’, ‘semi-relative’, ‘absolute’).
norm (int | str | None, default=None) – If provided, aggregates the parameter sensitivities into a single scalar metric using the specified norm (e.g., 2 for L2 norm, jnp.inf for max norm).
self (Self)
- Returns:
If norm is None, returns a dictionary mapping flat parameter names to sensitivity arrays. If norm is specified, returns a 0D scalar jax array representing the global sensitivity metric.
- Return type:
- iter_params(param_filter=None, *, include_fixed=False, flatten=False, submodules=None)
Iterate over (name, Parameter) pairs in internal order.
- merged(modules)
Merge the free parameters and parameter groups of other modules into this module.
This is useful for combining separate modules obtained by fitting the same initial module with different sets of free parameters.
- Parameters:
modules (Module or Sequence[Module]) – The other modules to combine this module with.
self (Self)
- Return type:
Module
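One plausible reading of the merge, sketched with plain dicts: each module is a mapping from name to (value, is_free), and free entries from the other modules overwrite the base. The dict representation, the rule that a later module wins when two modules free the same parameter, and the omission of parameter groups are all assumptions of this sketch, not documented behavior.

```python
def merged_sketch(base, others):
    """Hypothetical: modules are dicts of name -> (value, is_free).

    Free parameters in `others` overwrite the matching entries in `base`;
    if several modules free the same name, the last one wins (assumption).
    Parameter groups are ignored here.
    """
    result = dict(base)  # copy so the base module is left untouched
    for other in others:
        for name, (value, is_free) in other.items():
            if is_free:
                result[name] = (value, True)
    return result


base = {"a": (1.0, False), "b": (1.0, False)}
fit_a = {"a": (2.5, True), "b": (1.0, False)}  # fitted with only a free
fit_b = {"a": (1.0, False), "b": (0.3, True)}  # fitted with only b free
combined = merged_sketch(base, [fit_a, fit_b])
```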
- named_flat_param_values(scaled=False, return_floats=False, **kwargs)
Named flattened module parameter values as a dict of jax arrays.
See parax.Module.named_flat_params.
- named_flat_params(include_fixed=False, submodules=None)
Named flattened module parameters as a dict.
Flat parameters are a de-vectorized version of the internal parameters of the module. The returned parameter objects therefore are not necessarily equal to the internal module objects.
Keys are fully-qualified parameter names with de-vectorized suffixes added. The order matches the internal flattened array order.
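De-vectorization can be pictured as expanding each vector-valued parameter into one scalar entry per element, in order. The `_<index>` suffix below is an assumed naming scheme for illustration only; the actual suffix format parax uses is not stated here, and `named_flat_params_sketch` is a hypothetical name.

```python
def named_flat_params_sketch(params):
    """Hypothetical: expand list-valued parameters into scalar entries.

    `params` maps names to floats or lists of floats. The `_<index>` suffix
    is an assumed naming scheme; the real suffix format may differ.
    """
    flat = {}
    for name, value in params.items():  # dicts preserve insertion order
        if isinstance(value, (list, tuple)):
            for i, v in enumerate(value):
                flat[f"{name}_{i}"] = v
        else:
            flat[name] = value
    return flat


flat = named_flat_params_sketch({"gain": 2.0, "poles": [0.1, 0.2]})
```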
- named_param_values(scaled=False, **kwargs)
Named module parameter values as a dict of jax arrays.
See parax.Module.named_params.
- named_params(param_filter=None, *, include_fixed=False, submodules=None)
Named module parameters as a dict.
Keys are fully-qualified parameter names. The order matches the internal flattened array order.
- Parameters:
param_filter (str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool], default=None) – A filter indicating which parameters to return. For the default case, all parameters are returned.
include_fixed (bool, default=False) – Include fixed parameters.
submodules (Module | Sequence[Module] | str | Sequence[str] | None, optional) – Restrict to parameters used by the given submodule(s). If strings are provided, getattr(self, name) is used.
- Return type:
- param(name, *args, **kwargs)
Return a single module parameter by name.
See parax.Module.named_params.
- Parameters:
name (str)
- Return type:
Parameter
- param_groups(include_fixed=False, explicit_only=False)
Return all parameter groups relevant to this module, including submodules.
This function recursively traverses submodules to collect their parameter groups, adjusting parameter names to match the current module’s scope.
Priority is given to groups defined in the parent module. If a parameter is grouped explicitly in self._param_groups, it will be removed from any groups returned by submodules.
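The priority rule, sketched for a single level of nesting with groups as tuples of names. The `_` separator used to re-scope names follows the path format described under with_submodules; dropping a submodule group that ends up empty is an assumption, and `collect_groups` is a hypothetical name.

```python
def collect_groups(parent_groups, submodule_groups):
    """Hypothetical: merge groups, giving the parent's explicit groups priority.

    `parent_groups` is a list of name tuples in the parent's scope;
    `submodule_groups` maps a submodule name to groups in its local scope.
    """
    claimed = {name for group in parent_groups for name in group}
    result = [tuple(group) for group in parent_groups]
    for sub_name, groups in submodule_groups.items():
        for group in groups:
            # Re-scope local names to the parent, e.g. "c" -> "enc_c".
            scoped = tuple(f"{sub_name}_{n}" for n in group)
            # Parameters already grouped by the parent are removed.
            remaining = tuple(n for n in scoped if n not in claimed)
            if remaining:  # assumption: groups emptied this way are dropped
                result.append(remaining)
    return result


groups = collect_groups(
    [("enc_a", "dec_b")],
    {"enc": [("a", "c")], "dec": [("b",), ("d", "e")]},
)
```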
- param_names(*args, **kwargs)
Return module parameter names as a list.
See parax.Module.named_params.
- param_value(name, *args, **kwargs)
Return a single module parameter value by name as a single jax array.
See parax.Module.named_param_values.
- Parameters:
name (str)
- Return type:
Array
- param_values(*args, **kwargs)
Return module parameter values as a list of jax arrays.
See parax.Module.named_param_values.
- Return type:
list[Array]
- params(*args, **kwargs)
Return module parameters as a list.
See parax.Module.named_params.
- Return type:
list[Parameter]
- path_to_param_name(path)
Convert a PyTree path to a fully-qualified parameter name.
- Return type:
- sampled(key=None, **kwargs)
Returns a new module with parameters sampled from this module’s distribution.
- Return type:
Module
- submodules()
Returns all nested submodules (depth-first), excluding self.
- Return type:
list[Module]
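A minimal sketch of the traversal over a hypothetical `Node` type. The docs say only "depth-first"; the pre-order visiting order used here (each submodule before its own children) is an assumption.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a module: a name plus child modules."""
    name: str
    children: list = field(default_factory=list)


def submodules_sketch(module):
    """All nested submodules in depth-first pre-order, excluding `module`."""
    out = []
    for child in module.children:
        out.append(child)  # visit the child first (pre-order)...
        out.extend(submodules_sketch(child))  # ...then everything below it
    return out


tree = Node("root", [Node("a", [Node("c")]), Node("b")])
names = [m.name for m in submodules_sketch(tree)]
```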
- with_all_params_fixed(**kwargs)
Returns a module with all parameters fixed.
This is an alias for calling parax.Module.with_free_params with fix_others=True and no parameters passed.
See parax.Module.with_free_params.
- with_all_params_free(**kwargs)
Returns a module with all parameters free.
This is an alias for calling parax.Module.with_free_params with all parameters passed.
See parax.Module.with_free_params.
- with_attrs(*args, **kwargs)
Return a copy of the module with one or more attributes replaced.
This is similar to eqx.tree_at but uses string paths.
Usage
# 1. Single attribute update (path, value)
model.with_attrs('a.b.c', 10)
# 2. Batch nested updates via a dictionary
model.with_attrs({'a.b.c': 10, 'x.y.z': 20})
# 3. Top-level attributes via keyword arguments
model.with_attrs(name="new_model", _transparent=True)
# 4. Combined dict and kwargs
model.with_attrs({'a.b.c': 10}, name="new_model")
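String paths can be resolved by rebuilding each object along the path with `dataclasses.replace`, leaving the original tree untouched. A hypothetical sketch for frozen dataclasses and a single dotted path (the batch and keyword forms are not handled, and this is not parax's implementation):

```python
import dataclasses


def with_attrs_sketch(obj, path, value):
    """Hypothetical: copy of a frozen-dataclass tree with `path` replaced.

    `path` is dot-separated, e.g. "a.b.c"; every object along it is rebuilt
    with dataclasses.replace, so the original tree is not modified.
    """
    head, _, rest = path.partition(".")
    if not rest:
        return dataclasses.replace(obj, **{head: value})
    child = getattr(obj, head)
    return dataclasses.replace(obj, **{head: with_attrs_sketch(child, rest, value)})


@dataclasses.dataclass(frozen=True)
class Inner:
    c: int


@dataclasses.dataclass(frozen=True)
class Middle:
    b: Inner


@dataclasses.dataclass(frozen=True)
class Outer:
    a: Middle


model = Outer(Middle(Inner(1)))
updated = with_attrs_sketch(model, "a.b.c", 10)
```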
- classmethod with_defaults(*args, **kwargs)
Return this module type with default initialization arguments.
This is useful for reusing an existing module type with different default values, without defining a new subclass.
Arguments are forwarded as if they were passed to __init__.
- Return type:
type[Module]
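One way to get this behavior in plain Python is to pre-bind `__init__` arguments with `functools.partialmethod` on a generated subclass, so the caller never writes the subclass. This is a hypothetical sketch, not parax's mechanism, and it targets ordinary classes rather than parax modules.

```python
import functools


def with_defaults_sketch(cls, *args, **kwargs):
    """Hypothetical: return a subclass of `cls` whose __init__ pre-binds arguments.

    Arguments are forwarded to __init__ as if passed at construction time;
    keyword arguments given at construction override the bound defaults.
    """
    init = functools.partialmethod(cls.__init__, *args, **kwargs)
    return type(cls.__name__, (cls,), {"__init__": init})


class Filter:
    def __init__(self, order, cutoff=1.0):
        self.order = order
        self.cutoff = cutoff


SecondOrder = with_defaults_sketch(Filter, order=2)
```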
- with_demoted_param_groups()
Recursively demote parameter groups to the deepest possible submodule.
This method identifies parameter groups where every parameter belongs to the same immediate submodule. It moves those groups to the submodule, stripping the prefix. It then recursively calls this method on the submodules to ensure groups continue moving down the hierarchy as far as possible.
- Returns:
A new module instance with parameter groups distributed to their lowest relevant submodules.
- Return type:
Self
- Parameters:
self (Self)
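The demotion rule can be sketched on parameter paths stored as tuples of name components: a group moves down one level while all of its members share the same next component and that component is a submodule rather than the parameter itself. This iterative, hypothetical `demote_groups` returns where each group would land, with prefixes stripped; it models the placement only, not the module rebuild.

```python
def demote_groups(groups):
    """Hypothetical sketch: find the deepest submodule owning each group.

    Each group is a tuple of parameter paths; each path is a tuple of name
    components such as ("enc", "attn", "w"). Returns a dict mapping a module
    path to the groups placed there, with that module prefix stripped.
    """
    placed = {}
    for group in groups:
        prefix = ()
        # Descend while every member is still inside a submodule (len > 1)
        # and all members share that same submodule.
        while all(len(p) > 1 for p in group) and len({p[0] for p in group}) == 1:
            prefix += (group[0][0],)
            group = tuple(p[1:] for p in group)
        placed.setdefault(prefix, []).append(group)
    return placed


groups = [
    (("enc", "attn", "w"), ("enc", "attn", "b")),
    (("enc", "x"), ("dec", "y")),
    (("dec", "z"), ("dec", "u")),
]
placement = demote_groups(groups)
```

The second group spans two submodules, so it stays at the top level.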
- with_fields(*args, **kwargs)
Return a copy of this module with dataclass-style field replacements.
Parameters are forwarded to dataclasses.replace.
- with_fixed_params(param_filter, free_others=False, **kwargs)
Return a module with specified parameters fixed.
This maps each parameter in the filter, calling parax.Parameter.as_fixed on each.
See parax.Module.with_mapped_params.
- with_fixed_submodules(submodules)
Fix all parameters in the given submodules.
Submodule parameters are obtained using parax.Module.param_names and then fixed using parax.Module.with_fixed_params.
- with_free_params(param_filter, *, fix_others=False, **kwargs)
Free the specified parameters.
This maps each parameter in the filter, calling parax.Parameter.as_free on each.
See parax.Module.with_mapped_params.
- with_free_params_only(param_filter, **kwargs)
Returns a module with only the specified parameters freed.
This is an alias for calling parax.Module.with_free_params with fix_others=True.
See parax.Module.with_free_params.
- with_free_submodules(submodules, fix_others=False, include_fixed=True)
Free all parameters in the given submodules.
Submodule parameters are obtained using parax.Module.param_names and then freed using parax.Module.with_free_params.
- with_free_submodules_only(*args, include_fixed=False, **kwargs)
Returns a module with only the specified submodules freed.
This is an alias for calling parax.Module.with_free_submodules with fix_others=True and include_fixed=False by default.
See parax.Module.with_free_params.
- with_mapped_distributions(mapper, dist_filter=None, *, map_others=None, param_groups=False)
Return a module with a function applied to its parameter distributions.
This method allows for bulk-updates of distributions, such as widening variances or changing distribution types.
If param_groups is False, the mapping is applied to the distributions of individual parameters (flattened).
If param_groups is True, the mapping is applied to the distributions of parax.ParameterGroup objects. This mode is recursive: it traverses the module tree and applies the mapping to all explicit parameter groups in all submodules.
- Parameters:
mapper (Callable[[AbstractDistribution], AbstractDistribution]) – Function that takes a distribution and returns a new one.
dist_filter (Callable[[AbstractDistribution], bool] | None, default=None) – A predicate function. If provided, the mapping is only applied to distributions where dist_filter(dist) is True. If None, the mapping applies to all distributions.
map_others (Callable[[AbstractDistribution], AbstractDistribution] | None, default=None) – An optional map to apply to all distributions NOT in the filter.
param_groups (bool, default=False) – If True, map distributions on parameter groups (recursively). If False, map distributions on individual parameters (flat).
self (Self)
- Returns:
A new module with updated distributions.
- Return type:
Self
- with_mapped_params(mapper, param_filter=None, *, map_others=None, prefixes=False, include_fixed=False, ignore_unknown=False)
Return a module with specified parameters mapped.
- Parameters:
mapper (Callable[[Parameter], Parameter]) – The map to apply to each parameter in the filter (or all if no filter).
param_filter (str | Sequence[str] | Callable[[str], bool] | None, default=None) – Parameter names to map. If None, applies mapper to all parameters.
map_others (Callable[[Parameter], Parameter] | None, default=None) – An optional map to apply to all parameters NOT in the filter.
prefixes (bool, default=False) – Specifies that, when a string or list of strings is passed in param_filter, these must be interpreted as parameter prefixes to map and not full path names. Defaults to False.
self (Self)
include_fixed (bool)
ignore_unknown (bool)
- Return type:
Self
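The filter semantics can be sketched on a plain name-to-value dict. `map_params_sketch` is hypothetical, covers param_filter, map_others, and prefixes only (not include_fixed or ignore_unknown), and applies the mapper to values instead of Parameter objects.

```python
def map_params_sketch(params, mapper, param_filter=None, *, map_others=None, prefixes=False):
    """Hypothetical: apply `mapper` to selected entries of a name -> value dict.

    `param_filter` may be None (select all), a string, a list of strings, or
    a predicate on names. With prefixes=True, strings match name prefixes
    instead of full names. Unselected entries go through `map_others` if given.
    """
    if param_filter is None:
        def selected(name):
            return True
    elif callable(param_filter):
        selected = param_filter
    else:
        wanted = [param_filter] if isinstance(param_filter, str) else list(param_filter)

        def selected(name):
            if prefixes:
                return any(name.startswith(w) for w in wanted)
            return name in wanted

    out = {}
    for name, value in params.items():
        if selected(name):
            out[name] = mapper(value)
        elif map_others is not None:
            out[name] = map_others(value)
        else:
            out[name] = value  # untouched when no map_others is given
    return out


p = {"enc_w": 1.0, "enc_b": 2.0, "dec_w": 3.0}
scaled = map_params_sketch(p, lambda v: v * 10, "enc", prefixes=True, map_others=lambda v: -v)
```

Without prefixes=True, the string "enc" would only match a parameter named exactly "enc".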
- with_name(name)
Return a copy of this module with a different name.
- with_no_param_groups()
Return a new module with all parameter groups removed recursively.
This clears the _param_groups of the current module and traverses all nested submodules (and sequences of submodules) to remove their parameter groups as well.
- Returns:
A new module instance with no parameter groups.
- Return type:
Self
- Parameters:
self (Self)
- with_param_groups(param_groups)
Return a module with parameter groups appended, replacing existing relationships.
This method implements an “atomic replacement” policy. If any parameter in an existing group is claimed by a new group, the entire existing group is removed.
This ensures that groups defining joint distributions are not left in an invalid broken state (e.g. having a dimension removed). Parameters that were in the removed group but not in the new group will revert to being ungrouped (handled by param_groups as singleton groups).
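The atomic replacement policy reduces to a set check, sketched here with groups as tuples of parameter names (a hypothetical representation and function name):

```python
def with_param_groups_sketch(existing, new):
    """Hypothetical: groups are tuples of parameter names.

    Any existing group sharing at least one name with a new group is dropped
    whole; its other members simply become ungrouped. New groups are appended.
    """
    claimed = {name for group in new for name in group}
    kept = [g for g in existing if claimed.isdisjoint(g)]
    return kept + [tuple(g) for g in new]


result = with_param_groups_sketch([("a", "b", "c"), ("d", "e")], [("a", "x")])
```

Here b and c are not carried into any group: the whole ("a", "b", "c") group goes away because a was claimed.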
- with_params(params=None, check_missing=False, check_unknown=True, fix_others=False, include_fixed=False, **param_kwargs)
Return a new module with parameters updated.
This is a multi-purpose function that updates parameters differently depending on the type of the updates passed.
- Parameters:
params (dict[str, Parameter] | dict[str, float] | jnp.ndarray | None, optional) – Parameter updates. If an array, all values must be provided (matching flat_params order). You may also pass keyword arguments.
check_missing (bool, default=False) – Require that all module parameters are specified.
check_unknown (bool, default=True) – Error if unknown parameter keys are provided.
fix_others (bool, default=False) – Fix any parameters not explicitly passed.
include_fixed (bool, default=False) – Include fixed parameters when interpreting the params mapping.
**param_kwargs (dict) – Additional parameter updates by name.
self (Self)
- Return type:
Self
- Raises:
Exception – If shape/order mismatches, unknown/missing names (when checked), or if arrays are found outside of Parameters.
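A sketch of how the array and mapping forms differ, using an ordered dict of free parameter values as a stand-in for the module and a list in place of a jnp.ndarray. The function name, the error types, and the restriction to free parameters are assumptions for illustration.

```python
def with_params_sketch(free, params=None, *, check_missing=False, check_unknown=True, **param_kwargs):
    """Hypothetical: `free` is an ordered name -> value dict of free parameters.

    A list/tuple must supply every value in flat order; a dict (plus keyword
    arguments) may update any subset, subject to the checks.
    """
    if isinstance(params, (list, tuple)):
        if len(params) != len(free):
            raise ValueError(f"expected {len(free)} values, got {len(params)}")
        updates = dict(zip(free, params))  # positional: match flat order
    else:
        updates = dict(params or {})
    updates.update(param_kwargs)
    unknown = set(updates) - set(free)
    if check_unknown and unknown:
        raise KeyError(f"unknown parameters: {sorted(unknown)}")
    missing = set(free) - set(updates)
    if check_missing and missing:
        raise KeyError(f"missing parameters: {sorted(missing)}")
    # Build a new dict; the input is never mutated.
    return {name: updates.get(name, value) for name, value in free.items()}


free = {"a": 1.0, "b": 2.0}
from_array = with_params_sketch(free, [3.0, 4.0])
from_kwargs = with_params_sketch(free, b=6.0)
```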
- with_submodule_fields(submodule, *args, **kwargs)
Return a copy of this module with dataclass-style field replacements on a nested submodule.
Parameters are forwarded to dataclasses.replace.
- with_submodules(*args, **kwargs)
Return a copy of the module with one or more submodules replaced.
This method accepts paths formatted in the exact same way as parameter names (e.g. ‘submodule1_submodule2_submodule3’), respecting transparency and custom names.
Usage
# Single replacement
model.with_submodules('layer1_attention', new_attention_module)
# Batch replacement
model.with_submodules({
    'layer1_attention': new_attn_1,
    'layer2_attention': new_attn_2,
})
- with_transformed_params(bijector, param_filter=None, **kwargs)
Return a module with a distreqx bijector applied to the specified parameters.
This uses the transformed method of each matched Parameter, which updates its physical value, bounds, and distribution simultaneously while preserving the unconstrained latent value.
- with_uniform_distributions(percentage, param_filter=None, *, respect_bounds=False, remove_param_groups=True, zero_values='keep', **kwargs)
Return a module with uniform distributions set centered on current parameter values.
The distributions are defined with bounds calculated as value * (1.0 +/- percentage).
- Parameters:
percentage (float) – The fractional width of the uniform distribution (e.g. 0.1 = 10%).
param_filter (str | Sequence[str] | Callable[[str], bool], default=None) – The parameters to be updated with new uniform distributions. For the default case, all are updated.
respect_bounds (bool, default=False) – Whether the min and max bounds of the current distributions should be respected. If True, the new bounds do not extend past these bounds.
remove_param_groups (bool, default=True) – Whether to remove parameter groups recursively when setting the uniform distributions. Otherwise, the joint distribution of the module may not be the desired uniform distribution.
zero_values (str, default='keep') – How to treat zero values. Currently the only option is to keep them and their bounds as is.
- Returns:
A new module with updated parameter distributions.
- Return type:
Self
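The bound arithmetic, sketched as a standalone function. Sorting the two endpoints (so negative values still give low <= high) and returning None to mean "leave unchanged" for zero values are assumptions of this hypothetical `uniform_bounds`, not documented behavior.

```python
def uniform_bounds(value, percentage, bounds=None, respect_bounds=False):
    """Hypothetical: (low, high) of a uniform distribution centred on `value`.

    Uses value * (1 -/+ percentage), sorted so negative values still give
    low <= high, optionally clipped to the existing `bounds`.
    """
    if value == 0:
        return None  # zero_values='keep': leave the current distribution alone
    low, high = sorted((value * (1.0 - percentage), value * (1.0 + percentage)))
    if respect_bounds and bounds is not None:
        # Clip so the new interval never extends past the existing bounds.
        low, high = max(low, bounds[0]), min(high, bounds[1])
    return low, high


low, high = uniform_bounds(2.0, 0.1)  # roughly (1.8, 2.2)
```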