RBFKernel

class RBFKernel(length_scale=1.0, *, name=None)

Bases: Kernel

Radial Basis Function (Squared Exponential) kernel.

Variables:

length_scale (prx.Parameter) – Characteristic length scale of the correlation (default 1.0).
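The squared-exponential kernel is conventionally written k(x, x') = exp(-||x - x'||^2 / (2 * length_scale^2)). The sketch below evaluates that textbook form in plain Python for illustration only. It assumes this convention (some libraries drop the factor of 2 or add an output-variance scale), and the helper name rbf is hypothetical, not part of parax.

```python
import math
from typing import Sequence


def rbf(x: Sequence[float], y: Sequence[float], length_scale: float = 1.0) -> float:
    """Textbook squared-exponential kernel (hypothetical helper, not parax's API)."""
    if length_scale <= 0:
        raise ValueError("length_scale must be positive")
    # Squared Euclidean distance between the two input points.
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return math.exp(-sq_dist / (2.0 * length_scale**2))


# Identical points are perfectly correlated.
print(rbf([0.0, 0.0], [0.0, 0.0]))  # 1.0
# A larger length scale makes the correlation decay more slowly with distance.
print(rbf([0.0], [1.0], length_scale=1.0) < rbf([0.0], [1.0], length_scale=2.0))  # True
```

This is what "characteristic length scale" refers to: the distance over which the correlation drops appreciably.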

Parameters:
  • length_scale (Parameter)

  • name (str | None)

length_scale: Parameter = 1.0
children()

Returns the immediate submodules.

Return type:

list[Module]

copy()

Returns a deep copy of self.

Return type:

Module

Parameters:

self (Self)

flat_param_names(*args, **kwargs)

Return flattened parameter names as a list.

See parax.Module.named_flat_params.

Return type:

list[str]

flat_param_values(*args, **kwargs)

Return flattened module parameter values as a JAX array.

See parax.Module.named_flat_param_values.

Return type:

Array

flat_params(*args, **kwargs)

Return flattened parameters as a list.

See parax.Module.named_flat_params.

Return type:

list[Parameter]

func_jacobian(func, args)

Calculate the Jacobian of an arbitrary function with respect to free parameters.

This uses forward-mode automatic differentiation to compute the gradients of the provided function with respect to each free parameter in the module.

Parameters:
  • func (Callable[[Module], jnp.ndarray]) – Function to differentiate. Must take a Module and args and return a jnp.ndarray of any shape.

  • args (Any) – The args to pass to func.

  • self (Self)

Returns:

A dictionary mapping flat parameter names to their gradient arrays. Each array has the same shape as the output of func.

Return type:

dict[str, jnp.ndarray]
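Forward mode pushes one tangent direction per free parameter through func, which is why the result holds one output-shaped array per parameter. The following sketch shows the idea with a minimal dual-number class instead of JAX (in JAX itself, jax.jvp plays this role). The names Dual, forward_jacobian, and model, and the dict-of-floats parameter representation, are hypothetical, and outputs are scalar for brevity.

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass(frozen=True)
class Dual:
    """Number val + tan*eps with eps**2 == 0; tan carries the directional derivative."""
    val: float
    tan: float = 0.0

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.tan + other.tan)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule: (a + a'eps)(b + b'eps) = ab + (a'b + ab')eps
        return Dual(self.val * other.val, self.tan * other.val + self.val * other.tan)

    __rmul__ = __mul__


def forward_jacobian(func: Callable[[Dict[str, Dual]], Dual],
                     params: Dict[str, float]) -> Dict[str, float]:
    """One forward pass per free parameter, seeding only that parameter's tangent with 1."""
    grads = {}
    for name in params:
        seeded = {k: Dual(v, 1.0 if k == name else 0.0) for k, v in params.items()}
        grads[name] = func(seeded).tan
    return grads


# Hypothetical model: y = a * x + b * x * x evaluated at x = 2.
def model(p: Dict[str, Dual]) -> Dual:
    x = 2.0
    return p["a"] * x + p["b"] * x * x


print(forward_jacobian(model, {"a": 1.0, "b": 3.0}))  # {'a': 2.0, 'b': 4.0}
```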

func_samples(func, args, *, key, num_samples=1000)

Evaluates an arbitrary function over samples drawn from the module’s distribution.

Parameters:
  • func (Callable[[Module], jnp.ndarray]) – A function that takes a Module instance and returns a JAX array.

  • args (Any) – The args to pass to func.

  • key (Array) – JAX random key for sampling.

  • num_samples (int, default=1000) – Number of modules to sample from the joint distribution.

Returns:

The function evaluated over all samples. Shape will be (num_samples, *func_output_shape).

Return type:

jnp.ndarray
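Conceptually, func_samples draws num_samples parameter sets, evaluates func on each, and stacks the results along a new leading axis. The sketch below assumes independent per-parameter samplers and uses random.Random as a stand-in for a JAX key. The real method samples from the module's joint distribution (including parameter groups) and may vectorize the evaluation, which this loop does not attempt. func_samples_sketch and its arguments are hypothetical.

```python
import random
from typing import Callable, Dict, List


def func_samples_sketch(func: Callable[[Dict[str, float]], List[float]],
                        samplers: Dict[str, Callable[[random.Random], float]],
                        seed: int = 0,
                        num_samples: int = 1000) -> List[List[float]]:
    """Evaluate func once per sampled parameter set.

    The result has num_samples rows, each shaped like func's output, mirroring
    the documented (num_samples, *func_output_shape) layout.
    """
    rng = random.Random(seed)  # stand-in for a JAX PRNG key
    rows = []
    for _ in range(num_samples):
        params = {name: draw(rng) for name, draw in samplers.items()}
        rows.append(func(params))
    return rows


# Hypothetical: length_scale ~ Uniform(0.5, 1.5); func returns a length-2 output.
samples = func_samples_sketch(
    lambda p: [p["length_scale"], 2.0 * p["length_scale"]],
    {"length_scale": lambda rng: rng.uniform(0.5, 1.5)},
    num_samples=100,
)
print(len(samples), len(samples[0]))  # 100 2
```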

func_sensitivity(func, args, kind='relative', norm=None)

Calculate the sensitivity of an arbitrary function with respect to the parameters.

Supported kinds:
  • ‘relative’: (dy/dtheta) * (theta/y). Fractional change in output per fractional change in parameter. Diverges when y is zero.

  • ‘semi-relative’: (dy/dtheta) * theta. Change in output per fractional change in parameter. Stable when y is zero.

  • ‘absolute’: (dy/dtheta). Raw gradient.

Parameters:
  • func (Callable[[Module], jnp.ndarray]) – Function to evaluate.

  • args (Any) – The args to pass to func.

  • kind (str, default='relative') – The type of sensitivity to calculate (‘relative’, ‘semi-relative’, ‘absolute’).

  • norm (int | str | None, default=None) – If provided, aggregates the parameter sensitivities into a single scalar metric using the specified norm (e.g., 2 for L2 norm, jnp.inf for max norm).

  • self (Self)

Returns:

If norm is None, returns a dictionary mapping flat parameter names to sensitivity arrays. If norm is specified, returns a 0-D (scalar) JAX array representing the global sensitivity metric.

Return type:

dict[str, jnp.ndarray] | jnp.ndarray
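The three kinds differ only in how the raw derivative is rescaled. The sketch below computes them for a scalar output and parameter, plus a p-norm aggregation. How func_sensitivity actually combines parameters and output elements when norm is given is an assumption here (all per-parameter values pooled into one vector); sensitivity and aggregate are hypothetical helpers.

```python
import math
from typing import Dict, Union


def sensitivity(dy_dtheta: float, theta: float, y: float, kind: str = "relative") -> float:
    """The three documented sensitivity kinds for a scalar output and parameter."""
    if kind == "relative":
        return dy_dtheta * theta / y  # raises ZeroDivisionError when y == 0
    if kind == "semi-relative":
        return dy_dtheta * theta
    if kind == "absolute":
        return dy_dtheta
    raise ValueError(f"unknown kind: {kind!r}")


def aggregate(sens: Dict[str, float], norm: Union[int, float]) -> float:
    """Assumed aggregation: a p-norm over all per-parameter sensitivities."""
    values = [abs(v) for v in sens.values()]
    if norm == math.inf:
        return max(values)  # max norm
    return sum(v**norm for v in values) ** (1.0 / norm)


# y = theta**2 at theta = 3: dy/dtheta = 6, y = 9.
print(sensitivity(6.0, 3.0, 9.0, "relative"))       # 2.0
print(sensitivity(6.0, 3.0, 9.0, "semi-relative"))  # 18.0
print(aggregate({"a": 3.0, "b": -4.0}, 2))          # 5.0
```

For a power law y = theta**n the relative sensitivity equals n at every theta, which makes the relative kind convenient for comparing parameters that live on different scales.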

iter_params(param_filter=None, *, include_fixed=False, flatten=False, submodules=None)

Iterate over (name, Parameter) pairs in internal order.

Parameters:
Return type:

Iterator[tuple[str, Parameter]]

merged(modules)

Merge the free parameters and parameter groups of other modules into this module.

This is useful for combining modules obtained by fitting the same initial module with different sets of free parameters.

Parameters:
  • modules (Module or Sequence[Module]) – The other modules to combine this module with.

  • self (Self)

Return type:

Module
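One way to picture merged: each fitted module contributes the parameters it had free, and those values are copied onto the result. This sketch uses a hypothetical flat representation (each name mapped to a (value, is_free) pair), ignores parameter groups, and assumes later modules win when two of them free the same parameter; none of that is confirmed by the documentation above.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical flat representation: name -> (value, is_free)
Params = Dict[str, Tuple[float, bool]]


def merged_sketch(base: Params, others: Sequence[Params]) -> Params:
    """Copy every free parameter of each other module onto a copy of base."""
    result = dict(base)
    for other in others:
        for name, (value, is_free) in other.items():
            if is_free:
                # Assumption: a later module overrides an earlier one on conflicts.
                result[name] = (value, True)
    return result


initial = {"a": (1.0, False), "b": (1.0, False)}
fit_a = {"a": (2.5, True), "b": (1.0, False)}  # fitted with only "a" free
fit_b = {"a": (1.0, False), "b": (0.7, True)}  # fitted with only "b" free
print(merged_sketch(initial, [fit_a, fit_b]))
# {'a': (2.5, True), 'b': (0.7, True)}
```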

name: str | None = None
named_flat_param_values(scaled=False, return_floats=False, **kwargs)

Named flattened module parameter values as a dict of JAX arrays.

See parax.Module.named_flat_params.

Parameters:
  • scaled (bool, default=False) – Whether or not to scale the returned values by the parameter scales.

  • **kwargs – Additional keyword arguments as in parax.Module.named_params.

Return type:

dict[str, jnp.ndarray]

named_flat_params(include_fixed=False, submodules=None)

Named flattened module parameters as a dict.

Flat parameters are a de-vectorized version of the module's internal parameters. The returned Parameter objects are therefore not necessarily identical to the module's internal ones.

Keys are fully-qualified parameter names with de-vectorized suffixes added. The order matches the internal flattened array order.

Parameters:
  • include_fixed (bool, default=False) – Include fixed parameters.

  • submodules (Module | Sequence[Module] | str | Sequence[str] | None, optional) – Restrict to parameters used by the given submodule(s). If strings are provided, getattr(self, name) is used.

Return type:

dict[str, Parameter]
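De-vectorization turns each vector-valued parameter into one scalar entry per element while keeping the internal order. The sketch below illustrates this with plain dicts and floats. The "_<index>" suffix is a hypothetical naming scheme, since the exact suffix format parax uses is not specified here, and named_flat is a hypothetical helper.

```python
from typing import Dict, List, Union

Value = Union[float, List[float]]


def named_flat(params: Dict[str, Value]) -> Dict[str, float]:
    """Split vector-valued parameters into scalar entries (hypothetical "_<i>" suffix)."""
    flat = {}
    for name, value in params.items():  # dicts preserve insertion (internal) order
        if isinstance(value, list):
            for i, v in enumerate(value):
                flat[f"{name}_{i}"] = v
        else:
            flat[name] = value
    return flat


print(named_flat({"length_scale": [1.0, 2.0], "variance": 0.5}))
# {'length_scale_0': 1.0, 'length_scale_1': 2.0, 'variance': 0.5}
```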

named_param_values(scaled=False, **kwargs)

Named module parameter values as a dict of JAX arrays.

See parax.Module.named_params.

Parameters:
  • scaled (bool, default=False) – Whether or not to scale the returned values by the parameter scales.

  • **kwargs – Additional keyword arguments as in parax.Module.named_params.

Return type:

dict[str, jnp.ndarray]

named_params(param_filter=None, *, include_fixed=False, submodules=None)

Named module parameters as a dict.

Keys are fully-qualified parameter names. The order matches the internal flattened array order.

Parameters:
  • param_filter (str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool], default=None) – A filter indicating which parameters to return. For the default case, all parameters are returned.

  • include_fixed (bool, default=False) – Include fixed parameters.

  • submodules (Module | Sequence[Module] | str | Sequence[str] | None, optional) – Restrict to parameters used by the given submodule(s). If strings are provided, getattr(self, name) is used.

Return type:

dict[str, Parameter]

property num_flat_params: int

Number of free, flattened parameters.

Return type:

int

property num_params: int

Number of free parameters.

Return type:

int

param(name, *args, **kwargs)

Return a single module parameter by name.

See parax.Module.named_params.

Parameters:

name (str)

Return type:

Parameter

param_groups(include_fixed=False, explicit_only=False)

Return all parameter groups relevant to this module, including submodules.

This function recursively traverses submodules to collect their parameter groups, adjusting parameter names to match the current module’s scope.

Priority is given to groups defined in the parent module. If a parameter is grouped explicitly in self._param_groups, it will be removed from any groups returned by submodules.

Parameters:

include_fixed (bool, default=False) – Include groups involving fixed parameters.

Return type:

list[ParameterGroup]
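The priority rule can be sketched for one level of nesting: parent groups are kept as-is, submodule groups are rescoped with the submodule prefix, and any parameter the parent already groups is dropped from them. Assumptions, all hypothetical: names are joined with "_" (the path style described under with_submodules), groups are plain lists of names, and submodule groups left empty are discarded. The real method applies this recursively through every level.

```python
from typing import Dict, List

Group = List[str]


def collect_groups(own: List[Group], children: Dict[str, List[Group]]) -> List[Group]:
    """Parent groups first; child groups are rescoped and lose parent-claimed names."""
    claimed = {name for group in own for name in group}
    result = [list(group) for group in own]
    for child_name, groups in children.items():
        for group in groups:
            # Rescope child-local names, e.g. "variance" -> "kernel_variance" (assumed style).
            scoped = [f"{child_name}_{name}" for name in group]
            kept = [name for name in scoped if name not in claimed]
            if kept:  # assumption: groups emptied by the parent are dropped
                result.append(kept)
    return result


print(collect_groups(
    own=[["kernel_length_scale", "noise"]],
    children={"kernel": [["length_scale", "variance"]]},
))
# [['kernel_length_scale', 'noise'], ['kernel_variance']]
```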

param_names(*args, **kwargs)

Return module parameter names as a list.

See parax.Module.named_params.

Return type:

list[str]

param_value(name, *args, **kwargs)

Return a single module parameter value by name as a JAX array.

See parax.Module.named_param_values.

Parameters:

name (str)

Return type:

Array

param_values(*args, **kwargs)

Return module parameter values as a list of JAX arrays.

See parax.Module.named_param_values.

Return type:

list[Array]

params(*args, **kwargs)

Return module parameters as a list.

See parax.Module.named_params.

Return type:

list[Parameter]

path_to_param_name(path)

Convert a PyTree path to a fully-qualified parameter name.

Return type:

str

sampled(key=None, **kwargs)

Returns a new module with parameters sampled from this module’s parameter distributions.

Return type:

Module

saveable()
Parameters:

self (Self)

Return type:

Self

submodules()

Returns all nested submodules (depth-first), excluding self.

Return type:

list[Module]

with_all_params_fixed(**kwargs)

Returns a module with all parameters fixed.

This is an alias for calling parax.Module.with_free_params with fix_others=True and no parameters passed.

See parax.Module.with_free_params.

Parameters:

self (Self)

Return type:

Self

with_all_params_free(**kwargs)

Returns a module with all parameters free.

This is an alias for calling parax.Module.with_free_params with all parameters passed.

See parax.Module.with_free_params.

Parameters:

self (Self)

Return type:

Self

with_attrs(*args, **kwargs)

Return a copy of the module with one or more attributes replaced.

This is similar to eqx.tree_at but uses string paths.

Usage

    # 1. Single attribute update (path, value)
    model.with_attrs('a.b.c', 10)

    # 2. Batch nested updates via dictionary
    model.with_attrs({'a.b.c': 10, 'x.y.z': 20})

    # 3. Top-level attributes via keyword arguments
    model.with_attrs(name="new_model", _transparent=True)

    # 4. Combined dict and kwargs
    model.with_attrs({'a.b.c': 10}, name="new_model")

Parameters:
Return type:

Self
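String-path updates like 'a.b.c' can be implemented by rebuilding each level of the tree on the way back up. This sketch does so for frozen dataclasses with dataclasses.replace. with_attr_path, Inner, and Outer are hypothetical stand-ins, and only the single (path, value) form is modeled, not the dict or keyword batches.

```python
import dataclasses
from typing import Any


def with_attr_path(obj: Any, path: str, value: Any) -> Any:
    """Return a copy of obj with the dotted attribute path replaced."""
    head, _, rest = path.partition(".")
    if not rest:
        return dataclasses.replace(obj, **{head: value})
    # Rebuild the child first, then swap it into a copy of this level.
    child = with_attr_path(getattr(obj, head), rest, value)
    return dataclasses.replace(obj, **{head: child})


@dataclasses.dataclass(frozen=True)
class Inner:
    c: int = 0


@dataclasses.dataclass(frozen=True)
class Outer:
    inner: Inner = dataclasses.field(default_factory=Inner)
    name: str = "model"


model = Outer()
updated = with_attr_path(model, "inner.c", 10)
print(updated.inner.c, model.inner.c)  # 10 0
```

Because every level is replaced rather than mutated, the original object is left untouched, the same immutability guarantee that eqx.tree_at provides.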

classmethod with_defaults(*args, **kwargs)

Return this module type with default initialization arguments.

This is useful for reusing an existing module type with different default values without having to define a new module type via inheritance.

Arguments are forwarded as if they were passed to __init__.

Return type:

type[Module]

with_demoted_param_groups()

Recursively demote parameter groups to the deepest possible submodule.

This method identifies parameter groups where every parameter belongs to the same immediate submodule. It moves those groups to the submodule, stripping the prefix. It then recursively calls this method on the submodules to ensure groups continue moving down the hierarchy as far as possible.

Returns:

A new module instance with parameter groups distributed to their lowest relevant submodules.

Return type:

Self

Parameters:

self (Self)
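A sketch of the demotion rule on a hypothetical nested-dict module ({"groups": ..., "children": ...}): a group moves into a child when every name in it carries that child's prefix, the prefix is stripped, and the step repeats inside the child. It assumes "<child>_<rest>" names (the path style described under with_submodules) and groups stored as lists of names; demote and _owner are hypothetical.

```python
from typing import Dict, List, Optional


def _owner(name: str, child_names) -> Optional[str]:
    """Return the child whose "<child>_" prefix the name carries, if any."""
    for child in child_names:
        if name.startswith(child + "_"):
            return child
    return None


def demote(module: Dict) -> Dict:
    """Return a copy with groups moved as deep as possible; the input is not mutated."""
    children = {
        k: {"groups": list(v["groups"]), "children": v["children"]}
        for k, v in module["children"].items()
    }
    kept: List[List[str]] = []
    for group in module["groups"]:
        owners = {_owner(name, children) for name in group}
        if len(owners) == 1 and None not in owners:
            child = owners.pop()
            # Strip "<child>_" so the names become local to the child.
            children[child]["groups"].append([n[len(child) + 1:] for n in group])
        else:
            kept.append(group)
    # Recurse so demoted groups keep sinking down the hierarchy.
    return {"groups": kept, "children": {k: demote(v) for k, v in children.items()}}


model = {
    "groups": [["kernel_base_a", "kernel_base_b"], ["kernel_c", "noise"]],
    "children": {"kernel": {"groups": [], "children": {"base": {"groups": [], "children": {}}}}},
}
print(demote(model)["children"]["kernel"]["children"]["base"]["groups"])  # [['a', 'b']]
```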

with_fields(*args, **kwargs)

Return a copy of this module with dataclass-style field replacements.

Arguments are forwarded to dataclasses.replace.

Parameters:

self (Self)

Return type:

Self

with_fixed_params(param_filter, free_others=False, **kwargs)

Return a module with specified parameters fixed.

This maps over each parameter matched by the filter, calling parax.Parameter.as_fixed on it.

See parax.Module.with_mapped_params.

Parameters:
Return type:

Self

with_fixed_submodules(submodules)

Fix all parameters in the given submodules.

The submodules' parameters are obtained using parax.Module.param_names and then fixed using parax.Module.with_fixed_params.

Parameters:
  • submodules (Module | Sequence[Module] | str | Sequence[str]) – Submodules whose parameters should be fixed.

  • self (Self)

Return type:

Self

with_free_params(param_filter, *, fix_others=False, **kwargs)

Free the specified parameters.

This maps over each parameter matched by the filter, calling parax.Parameter.as_free on it.

See parax.Module.with_mapped_params.

Parameters:
Return type:

Self

with_free_params_only(param_filter, **kwargs)

Returns a module with only the specified parameters freed.

This is an alias for calling parax.Module.with_free_params with fix_others=True.

See parax.Module.with_free_params.

Parameters:
Return type:

Self

with_free_submodules(submodules, fix_others=False, include_fixed=True)

Free all parameters in the given submodules.

The submodules' parameters are obtained using parax.Module.param_names and then freed using parax.Module.with_free_params.

Parameters:
  • submodules (Module | Sequence[Module] | str | Sequence[str]) – Submodules whose parameters should be free.

  • include_fixed (bool, default=True) – Include parameters of the submodules that are currently fixed.

  • fix_others (bool, default=False) – Fix all other submodules.

  • self (Self)

Return type:

Self

with_free_submodules_only(*args, include_fixed=False, **kwargs)

Returns a module with only the specified submodules freed.

This is an alias for calling parax.Module.with_free_submodules with fix_others=True and, by default, include_fixed=False.

See parax.Module.with_free_submodules.

Parameters:

self (Self)

Return type:

Self

with_mapped_distributions(mapper, dist_filter=None, *, map_others=None, param_groups=False)

Return a module with a function applied to its parameter distributions.

This method allows for bulk-updates of distributions, such as widening variances or changing distribution types.

If param_groups is False, the mapping is applied to the distributions of individual parameters (flattened).

If param_groups is True, the mapping is applied to the distributions of parax.ParameterGroup objects. This mode is recursive: it will traverse the module tree and apply the mapping to all explicit parameter groups in all submodules.

Parameters:
  • mapper (Callable[[AbstractDistribution], AbstractDistribution]) – Function that takes a distribution and returns a new one.

  • dist_filter (Callable[[AbstractDistribution], bool] | None, default=None) – A predicate function. If provided, the mapping is only applied to distributions where dist_filter(dist) is True. If None, applies to all.

  • map_others (Callable[[AbstractDistribution], AbstractDistribution] | None, default=None) – An optional map to apply to all distributions NOT in the filter.

  • param_groups (bool, default=False) – If True, map distributions on parameter groups (recursively). If False, map distributions on individual parameters (flat).

  • self (Self)

Returns:

A new module with updated distributions.

Return type:

Self

with_mapped_params(mapper, param_filter=None, *, map_others=None, prefixes=False, include_fixed=False, ignore_unknown=False)

Return a module with specified parameters mapped.

Parameters:
  • mapper (Callable[[Parameter], Parameter]) – The map to apply to each parameter in the filter (or all if no filter).

  • param_filter (str | Sequence[str] | Callable[[str], bool] | None, default=None) – Parameter names to map. If None, applies mapper to all parameters.

  • map_others (Callable[[Parameter], Parameter] | None, default=None) – An optional map to apply to all parameters NOT in the filter.

  • prefixes (bool, default=False) – If True, strings passed in param_filter are interpreted as parameter-name prefixes to match rather than full parameter names.

  • self (Self)

  • include_fixed (bool)

  • ignore_unknown (bool)

Return type:

Self
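The filter logic can be sketched on a plain name-to-value dict: param_filter may be None, a name, a list of names, or a predicate, and prefixes switches string matching from exact names to prefixes. Floats stand in for Parameter objects, map_params and _matches are hypothetical, and include_fixed and ignore_unknown are not modeled.

```python
from typing import Callable, Dict, Optional, Sequence, Union

Filter = Union[str, Sequence[str], Callable[[str], bool], None]


def _matches(name: str, param_filter: Filter, prefixes: bool) -> bool:
    if param_filter is None:
        return True  # no filter: every parameter matches
    if callable(param_filter):
        return param_filter(name)
    names = [param_filter] if isinstance(param_filter, str) else list(param_filter)
    if prefixes:
        return any(name.startswith(n) for n in names)
    return name in names


def map_params(params: Dict[str, float], mapper: Callable[[float], float],
               param_filter: Filter = None, *,
               map_others: Optional[Callable[[float], float]] = None,
               prefixes: bool = False) -> Dict[str, float]:
    """Apply mapper to matched parameters and map_others (if given) to the rest."""
    out = {}
    for name, value in params.items():
        if _matches(name, param_filter, prefixes):
            out[name] = mapper(value)
        elif map_others is not None:
            out[name] = map_others(value)
        else:
            out[name] = value
    return out


params = {"kernel_length_scale": 1.0, "kernel_variance": 2.0, "noise": 0.5}
print(map_params(params, lambda v: v * 10, "kernel", prefixes=True))
# {'kernel_length_scale': 10.0, 'kernel_variance': 20.0, 'noise': 0.5}
```

with_fixed_params and with_free_params follow this pattern with the mapper set to Parameter.as_fixed or Parameter.as_free, as their entries describe.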

with_name(name)

Return a copy of this module with a different name.

Parameters:
Return type:

Self

with_no_param_groups()

Return a new module with all parameter groups removed recursively.

This clears the _param_groups of the current module and traverses all nested submodules (and sequences of submodules) to remove their parameter groups as well.

Returns:

A new module instance with no parameter groups.

Return type:

Self

Parameters:

self (Self)

with_param_groups(param_groups)

Return a module with parameter groups appended, replacing existing relationships.

This method implements an “atomic replacement” policy. If any parameter in an existing group is claimed by a new group, the entire existing group is removed.

This ensures that groups defining joint distributions are not left in an invalid, broken state (e.g. having a dimension removed). Parameters that were in the removed group but not in the new group revert to being ungrouped (handled by param_groups as singleton groups).

Parameters:
  • param_groups (ParameterGroup or list[ParameterGroup]) – Group(s) to add.

  • self (Self)

Return type:

Self
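The atomic-replacement policy amounts to: drop every existing group that overlaps any new group, then append the new groups. A minimal sketch with groups modeled as frozensets of parameter names (a hypothetical representation; with_param_groups_sketch is not the library's API). In the example, 'a' and 'c' lose their group entirely and become ungrouped.

```python
from typing import FrozenSet, List

Group = FrozenSet[str]


def with_param_groups_sketch(existing: List[Group], new: List[Group]) -> List[Group]:
    """Append new groups; drop every existing group that shares any parameter with them."""
    claimed = set().union(*new)  # all parameter names claimed by the new groups
    kept = [g for g in existing if not (g & claimed)]
    return kept + list(new)


existing = [frozenset({"a", "b", "c"}), frozenset({"d", "e"})]
new = [frozenset({"b", "x"})]
result = with_param_groups_sketch(existing, new)
# {"a", "b", "c"} is removed as a whole because "b" is claimed; {"d", "e"} survives.
```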

with_params(params=None, check_missing=False, check_unknown=True, fix_others=False, include_fixed=False, **param_kwargs)

Return a new module with parameters updated.

This is a multi-purpose function that updates parameters differently depending on the type of the values passed.

Parameters:
  • params (dict[str, Parameter] | dict[str, float] | jnp.ndarray | None, optional) – Parameter updates. If an array, all values must be provided (matching flat_params order). You may also pass keyword args.

  • check_missing (bool, default=False) – Require that all module parameters are specified.

  • check_unknown (bool, default=True) – Error if unknown parameter keys are provided.

  • fix_others (bool, default=False) – Fix any parameters not explicitly passed.

  • include_fixed (bool, default=False) – Include fixed parameters when interpreting params mapping.

  • **param_kwargs (dict) – Additional parameter updates by name.

  • self (Self)

Return type:

Self

Raises:

Exception – If shape/order mismatches, unknown/missing names (when checked), or if arrays are found outside of Parameters.
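The two main input forms of params can be sketched as follows: a mapping updates parameters by name, while a sequence must supply every value in flat order. Plain lists stand in for jnp arrays, and with_params_sketch is hypothetical and models only check_unknown; check_missing, fix_others, include_fixed, and keyword updates are omitted.

```python
from collections import abc
from typing import Dict, Mapping, Sequence, Union


def with_params_sketch(params: Dict[str, float],
                       updates: Union[Mapping[str, float], Sequence[float]],
                       check_unknown: bool = True) -> Dict[str, float]:
    """Return a new dict; the input dict is left unchanged."""
    if isinstance(updates, abc.Mapping):
        unknown = set(updates) - set(params)
        if check_unknown and unknown:
            raise KeyError(f"unknown parameters: {sorted(unknown)}")
        # Named update: untouched parameters keep their current values.
        return {name: updates.get(name, value) for name, value in params.items()}
    values = list(updates)
    if len(values) != len(params):
        # Positional update: every value must be supplied, in order.
        raise ValueError(f"expected {len(params)} values, got {len(values)}")
    return dict(zip(params, values))


p = {"length_scale": 1.0, "variance": 2.0}
print(with_params_sketch(p, {"variance": 3.0}))  # {'length_scale': 1.0, 'variance': 3.0}
print(with_params_sketch(p, [0.5, 0.25]))        # {'length_scale': 0.5, 'variance': 0.25}
```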

with_submodule_fields(submodule, *args, **kwargs)

Return a copy of this module with dataclass-style field replacements on a nested sub-module.

Arguments are forwarded to dataclasses.replace.

Parameters:
  • submodule (str | Sequence[str]) – The name of the submodule (or sequence of names) to traverse. Can be a single string with a path e.g. ‘submodule1.submodule2’, or a list of submodules e.g. [‘submodule1’, ‘submodule2’].

  • self (Self)

Return type:

Self

with_submodules(*args, **kwargs)

Return a copy of the module with one or more submodules replaced.

This method accepts paths formatted in the exact same way as parameter names (e.g. ‘submodule1_submodule2_submodule3’), respecting transparency and custom names.

Usage

    # Single replacement
    model.with_submodules('layer1_attention', new_attention_module)

    # Batch replacement
    model.with_submodules({
        'layer1_attention': new_attn_1,
        'layer2_attention': new_attn_2,
    })

Parameters:
Return type:

Self

with_transformed_params(bijector, param_filter=None, **kwargs)

Return a module with a distreqx bijector applied to the specified parameters.

This calls the transformed method of each matched Parameter, which updates its physical value, bounds, and distribution together while preserving its unconstrained latent value.

Parameters:
  • bijector (distreqx.bijectors.AbstractBijector) – The bijector to apply.

  • param_filter (str | Sequence[str] | Callable[[str], bool] | None, default=None) – Parameter names to transform. If None, applies to all parameters.

  • self (Self)

Return type:

Self

with_uniform_distributions(percentage, param_filter=None, *, respect_bounds=False, remove_param_groups=True, zero_values='keep', **kwargs)

Return a module with uniform distributions centered on the current parameter values.

The distributions are defined with bounds calculated as value * (1.0 +/- percentage).

Parameters:
  • percentage (float) – The fractional width of the uniform distribution (e.g. 0.1 = 10%).

  • param_filter (str | Sequence[str] | Callable[[str], bool], default=None) – The parameters to be updated with new uniform distributions. For the default case, all are updated.

  • respect_bounds (bool, default=False) – Whether or not the min and max bounds of the current distributions should be respected. If True, the new bounds will not extend past these bounds.

  • remove_param_groups (bool, default=True) – Whether to remove parameter groups recursively when setting the uniform distributions. If False, the joint distribution of the module may not be the desired uniform distribution.

  • zero_values (str, default='keep') – How to treat zero values. Currently the only option is to keep them and their bounds as is.

Returns:

A new module with updated parameter distributions.

Return type:

Self
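The bound formula value * (1.0 +/- percentage) can be sketched as below. Two behaviors here are assumptions rather than documented facts: the bounds are reordered so that low <= high for negative values, and respect_bounds is modeled as clipping to the old (low, high) pair. A zero value would give a zero-width interval under this formula; the documented zero_values='keep' leaves such parameters and their bounds unchanged instead. uniform_bounds is hypothetical.

```python
from typing import Optional, Tuple


def uniform_bounds(value: float, percentage: float,
                   old_bounds: Optional[Tuple[float, float]] = None) -> Tuple[float, float]:
    """Bounds value * (1 - percentage) .. value * (1 + percentage), ordered low to high.

    If old_bounds is given (standing in for respect_bounds=True), clip to it.
    """
    a, b = value * (1.0 - percentage), value * (1.0 + percentage)
    low, high = min(a, b), max(a, b)  # assumption: negative values flip the order
    if old_bounds is not None:
        low, high = max(low, old_bounds[0]), min(high, old_bounds[1])
    return low, high


print(uniform_bounds(2.0, 0.5))               # (1.0, 3.0)
print(uniform_bounds(2.0, 0.5, (1.5, 10.0)))  # (1.5, 3.0)
print(uniform_bounds(-2.0, 0.5))              # (-3.0, -1.0)
```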