Terminated

class Terminated(from_model, into_model, *, name=None, z0=50 + 0j)

Bases: Model

Represents one network terminated in another.

Parameters:
  • from_model (Model)

  • into_model (Model)

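As a rough illustration of what terminating one network in another means, the sketch below applies the textbook relation for a 2-port whose second port is closed by a 1-port load. It is plain Python and does not use ParamRF. The helper name terminate_two_port is hypothetical, both networks are assumed to share the same real reference impedance, and this is not necessarily how Terminated is implemented internally.

```python
import cmath


def terminate_two_port(s, gamma_load):
    """Input reflection coefficient of a 2-port with port 2 closed by a load.

    `s` is a 2x2 nested sequence [[s11, s12], [s21, s22]] at one frequency,
    and `gamma_load` is the load's reflection coefficient (same real z0).
    """
    (s11, s12), (s21, s22) = s
    return s11 + s12 * s21 * gamma_load / (1 - s22 * gamma_load)


# A matched, lossless line of electrical length theta (radians).
theta = cmath.pi / 4
line = [[0, cmath.exp(-1j * theta)], [cmath.exp(-1j * theta), 0]]

gamma_short = -1 + 0j  # ideal short circuit
gamma_in = terminate_two_port(line, gamma_short)
print(gamma_in)  # the short seen through a round trip of 2 * theta
```

For a matched line the result is simply the load reflection rotated by twice the electrical length, which makes the example easy to check by hand.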
s(freq)

Scattering parameter matrix.

If a different parameter type (a, z, y) is primary, this converts it to S.

Note that, in ParamRF, the power wave definition of S-parameters should be used. If you have a formulation in terms of another definition (such as traveling waves), simply use pmrf.rf.s2s() (or pmrf.rf.renormalize_s() if you need to change impedance too).

Parameters:

freq (Frequency) – Frequency grid.

Returns:

S-parameter matrix with shape (nf, n, n).

Return type:

jnp.ndarray
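The note on S-parameter definitions in s(freq) only has practical consequences when the reference impedance is complex. The plain-Python sketch below compares the two reflection-coefficient definitions for a 1-port load. It does not call pmrf.rf.s2s(), and the helper names are hypothetical.

```python
def gamma_power_wave(z_load, z0):
    """Power-wave reflection coefficient: (Z - conj(z0)) / (Z + z0)."""
    return (z_load - z0.conjugate()) / (z_load + z0)


def gamma_traveling_wave(z_load, z0):
    """Traveling-wave reflection coefficient: (Z - z0) / (Z + z0)."""
    return (z_load - z0) / (z_load + z0)


z_load = 30 + 40j

# Real reference impedance: both definitions agree.
z0_real = 50 + 0j
print(gamma_power_wave(z_load, z0_real), gamma_traveling_wave(z_load, z0_real))

# Complex reference impedance: the definitions differ.
z0_complex = 50 + 10j
print(gamma_power_wave(z_load, z0_complex), gamma_traveling_wave(z_load, z0_complex))

# A conjugate-matched load has zero power-wave reflection.
print(gamma_power_wave(z0_complex.conjugate(), z0_complex))
```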

a(freq)

ABCD parameter matrix.

If a different parameter type is primary, this converts it to A.

Parameters:

freq (Frequency) – Frequency grid.

Returns:

ABCD matrix with shape (nf, 2, 2).

Return type:

jnp.ndarray
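To make the relationship between the ABCD and S representations concrete, here is a minimal sketch of the standard 2-port conversion for a single frequency point and a real reference impedance. The function name abcd_to_s is hypothetical and the code is independent of ParamRF's own conversion routines.

```python
def abcd_to_s(a, b, c, d, z0=50.0):
    """Convert 2-port ABCD parameters to S-parameters for a real z0."""
    denom = a + b / z0 + c * z0 + d
    s11 = (a + b / z0 - c * z0 - d) / denom
    s12 = 2 * (a * d - b * c) / denom
    s21 = 2 / denom
    s22 = (-a + b / z0 - c * z0 + d) / denom
    return [[s11, s12], [s21, s22]]


# A series impedance Z has the ABCD matrix [[1, Z], [0, 1]].
z_series = 50.0
s = abcd_to_s(1.0, z_series, 0.0, 1.0)
print(s)
```

A 50 ohm series element in a 50 ohm system reflects one third of the incident wave and transmits two thirds, which is a convenient sanity check.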

children()

Returns the immediate submodules.

Return type:

list[Module]

copy()

Returns a deep copy of self.

Return type:

Module

Parameters:

self (Self)

export_touchstone(filename, frequency, sigma=0.0, **skrf_kwargs)

Export the model response to a Touchstone file via scikit-rf.

Parameters:
  • filename (str)

  • frequency (Frequency | skrf.Frequency)

  • sigma (float, default=0.0) – Additive complex noise std for S-parameters.

  • **skrf_kwargs – Forwarded to skrf.Network.write_touchstone().

Returns:

Return value of Network.write_touchstone.

Return type:

Any

flat_param_names(*args, **kwargs)

Return flattened parameter names as a list.

See [parax.Module.named_flat_params][].

Return type:

list[str]

flat_param_values(*args, **kwargs)

Return flattened module parameter values as a JAX array.

See [parax.Module.named_flat_param_values][].

Return type:

Array

flat_params(*args, **kwargs)

Return flattened parameters as a list.

See [parax.Module.named_flat_params][].

Return type:

list[Parameter]

flipped(**kwargs)

Return a version of the model with ports flipped.

See pmrf.models.composite.transformed.Flipped.

Return type:

Model
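For a 2-port, flipping the ports amounts to swapping the port indices of the S-matrix. The sketch below shows this for a single frequency point in plain Python. The helper flip_two_port is hypothetical, and how Flipped treats networks with more than two ports is not covered here.

```python
def flip_two_port(s):
    """Swap ports 1 and 2 of a 2x2 S-matrix [[s11, s12], [s21, s22]]."""
    (s11, s12), (s21, s22) = s
    return [[s22, s21], [s12, s11]]


# Asymmetric example so that the swap is visible.
s = [[0.1, 0.2], [0.9, 0.4]]
print(flip_two_port(s))
```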

func_jacobian(func, args)

Calculate the Jacobian of an arbitrary function with respect to free parameters.

This uses forward-mode automatic differentiation to compute the gradients of the provided function with respect to each free parameter in the module.

Parameters:
  • func (Callable[[Module], jnp.ndarray]) – Function to differentiate. Must take a Module and args and return a jnp.ndarray of any shape.

  • args (Any) – The args to pass to func.

  • self (Self)

Returns:

A dictionary mapping flat parameter names to their gradient arrays. Each array has the same shape as the output of func.

Return type:

dict[str, jnp.ndarray]
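The idea behind forward-mode differentiation with one result per parameter name can be sketched with dual numbers in plain Python. This is only an illustration of the technique: the Dual class and the jacobian helper are hypothetical, support only addition and multiplication, and do not reflect how ParamRF or JAX implement it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """A value together with its derivative (tangent)."""
    value: float
    tangent: float = 0.0

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

    __rmul__ = __mul__


def _lift(x):
    """Promote a plain number to a Dual with zero tangent."""
    return x if isinstance(x, Dual) else Dual(float(x))


def jacobian(func, params):
    """Map each parameter name to d(func)/d(param), one forward pass per name."""
    grads = {}
    for name in params:
        # Seed a unit tangent on the parameter being differentiated.
        seeded = {k: Dual(v, 1.0 if k == name else 0.0) for k, v in params.items()}
        grads[name] = func(seeded).tangent
    return grads


# Hypothetical "model" output: y = R * C + 3 * R
params = {"R": 2.0, "C": 5.0}
grads = jacobian(lambda p: p["R"] * p["C"] + 3 * p["R"], params)
print(grads)
```

Forward mode needs one pass per free parameter, which suits models with few parameters and large outputs such as full frequency responses.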

func_samples(func, args, *, key, num_samples=1000)

Evaluates an arbitrary function over samples drawn from the module’s distribution.

Parameters:
  • func (Callable[[Module], jnp.ndarray]) – A function that takes a Module instance and returns a JAX array.

  • args (Any) – The args to pass to func.

  • key (Array) – JAX random key for sampling.

  • num_samples (int, default=1000) – Number of modules to sample from the joint distribution.

Returns:

The function evaluated over all samples. Shape will be (num_samples, *func_output_shape).

Return type:

jnp.ndarray
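A minimal sketch of evaluating a function over parameter samples, written in plain Python rather than JAX. The names evaluate_over_samples, draw_params and gamma are hypothetical, a seeded random.Random stands in for the JAX random key, and the uniform distribution is an assumption made only for this example.

```python
import random


def evaluate_over_samples(func, draw_params, *, seed, num_samples=1000):
    """Evaluate func on num_samples parameter draws; returns a list of outputs."""
    rng = random.Random(seed)  # stands in for the JAX random key
    return [func(draw_params(rng)) for _ in range(num_samples)]


# Hypothetical parameter distribution: R uniform within +/-10% of 50 ohm.
def draw_params(rng):
    return {"R": rng.uniform(45.0, 55.0)}


# Function of the parameters: reflection coefficient of R against 50 ohm.
def gamma(params):
    return (params["R"] - 50.0) / (params["R"] + 50.0)


samples = evaluate_over_samples(gamma, draw_params, seed=0, num_samples=200)
print(len(samples), min(samples), max(samples))
```

The leading axis of the result indexes the samples, mirroring the (num_samples, *func_output_shape) shape described above.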

func_sensitivity(func, args, kind='relative', norm=None)

Calculate the sensitivity of an arbitrary function w.r.t parameters.

Supported kinds:
  • ‘relative’: (dy/dtheta) * (theta/y). Fractional change in output per fractional change in parameter. Blows up if y is zero.

  • ‘semi-relative’: (dy/dtheta) * theta. Change in output per fractional change in parameter. Stable if y is zero.

  • ‘absolute’: (dy/dtheta). Raw gradient.

Parameters:
  • func (Callable[[Module], jnp.ndarray]) – Function to evaluate.

  • args (Any) – The args to pass to func.

  • kind (str, default='relative') – The type of sensitivity to calculate (‘relative’, ‘semi-relative’, ‘absolute’).

  • norm (int | str | None, default=None) – If provided, aggregates the parameter sensitivities into a single scalar metric using the specified norm (e.g., 2 for L2 norm, jnp.inf for max norm).

  • self (Self)

Returns:

If norm is None, returns a dictionary mapping flat parameter names to sensitivity arrays. If norm is specified, returns a 0D scalar jax array representing the global sensitivity metric.

Return type:

dict[str, jnp.ndarray] | jnp.ndarray
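The three sensitivity kinds differ only in how the raw gradient is scaled. The sketch below applies them to a scalar example with an analytic derivative. The helper sensitivity and the LC-resonator example are hypothetical, and the L2 aggregation over parameters is an assumption about what the norm argument does.

```python
import math


def sensitivity(y, dy_dtheta, theta, kind="relative"):
    """Scale a raw gradient according to the requested sensitivity kind."""
    if kind == "absolute":
        return dy_dtheta
    if kind == "semi-relative":
        return dy_dtheta * theta
    if kind == "relative":
        return dy_dtheta * theta / y  # undefined when y == 0
    raise ValueError(f"unknown kind: {kind!r}")


# Hypothetical example: resonant frequency f0 = 1 / (2*pi*sqrt(L*C)).
L, C = 1e-9, 1e-12
f0 = 1 / (2 * math.pi * math.sqrt(L * C))
df0_dC = -f0 / (2 * C)  # analytic derivative with respect to C
df0_dL = -f0 / (2 * L)  # analytic derivative with respect to L

for kind in ("absolute", "semi-relative", "relative"):
    print(kind, sensitivity(f0, df0_dC, C, kind))

# Aggregate the relative sensitivities of all parameters with an L2 norm.
relative = {"L": sensitivity(f0, df0_dL, L), "C": sensitivity(f0, df0_dC, C)}
l2 = math.sqrt(sum(v * v for v in relative.values()))
print(l2)
```

Because f0 scales as the inverse square root of each component, a 1% increase in either L or C lowers f0 by about 0.5%, which is what the relative kind reports.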

iter_params(param_filter=None, *, include_fixed=False, flatten=False, submodules=None)

Iterate over (name, Parameter) pairs in internal order.

Return type:

Iterator[tuple[str, Parameter]]

merged(modules)

Merge this module with free parameters and parameter groups in other modules.

This is useful to combine separate modules obtained from fitting the same initial module with different free parameters.

Parameters:
  • modules (Module or Sequence[Module]) – The other modules to combine this module with.

  • self (Self)

Return type:

Module

name: str | None = None

named_flat_param_values(scaled=False, return_floats=False, **kwargs)

Named flattened module parameter values as a dict of jax arrays.

See [parax.Module.named_flat_params][].

Parameters:
  • scaled (bool, default=False) – Whether or not to scale the returned values by the parameter scales.

  • **kwargs – Additional key-word arguments as in [parax.Module.named_params][].

Return type:

dict[str, jnp.ndarray]

named_flat_params(include_fixed=False, submodules=None)

Named flattened module parameters as a dict.

Flat parameters are a de-vectorized version of the internal parameters of the module. The returned parameter objects therefore are not necessarily equal to the internal module objects.

Keys are fully-qualified parameter names with de-vectorized suffixes added. The order matches the internal flattened array order.

Parameters:
  • include_fixed (bool, default=False) – Include fixed parameters.

  • submodules (Module | Sequence[Module] | str | Sequence[str] | None, optional) – Restrict to parameters used by the given submodule(s). If strings are provided, getattr(self, name) is used.

Return type:

dict[str, Parameter]

named_param_values(scaled=False, **kwargs)

Named module parameter values as a dict of jax arrays.

See [parax.Module.named_params][].

Parameters:
  • scaled (bool, default=False) – Whether or not to scale the returned values by the parameter scales.

  • **kwargs – Additional key-word arguments as in [parax.Module.named_params][].

Return type:

dict[str, jnp.ndarray]

named_params(param_filter=None, *, include_fixed=False, submodules=None)

Named module parameters as a dict.

Keys are fully-qualified parameter names. The order matches the internal flattened array order.

Parameters:
  • param_filter (str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool], default=None) – A filter indicating which parameters to return. For the default case, all parameters are returned.

  • include_fixed (bool, default=False) – Include fixed parameters.

  • submodules (Module | Sequence[Module] | str | Sequence[str] | None, optional) – Restrict to parameters used by the given submodule(s). If strings are provided, getattr(self, name) is used.

Return type:

dict[str, Parameter]

property nports: int

Alias of number_of_ports.

property num_flat_params: int

Number of free, flattened parameters.

Return type:

int

property num_params: int

Number of free parameters.

Return type:

int

property number_of_ports: int

Number of ports.

Return type:

int

param(name, *args, **kwargs)

Return a single module parameter by name.

See [parax.Module.named_params][].

Parameters:

name (str)

Return type:

Parameter

param_groups(include_fixed=False, explicit_only=False)

Return all parameter groups relevant to this module, including submodules.

This function recursively traverses submodules to collect their parameter groups, adjusting parameter names to match the current module’s scope.

Priority is given to groups defined in the parent module. If a parameter is grouped explicitly in self._param_groups, it will be removed from any groups returned by submodules.

Parameters:

include_fixed (bool, default=False) – Include groups involving fixed parameters.

Return type:

list[ParameterGroup]

param_names(*args, **kwargs)

Return module parameter names as a list.

See [parax.Module.named_params][].

Return type:

list[str]

param_value(name, *args, **kwargs)

Return a single module parameter value by name as a single jax array.

See [parax.Module.named_param_values][].

Parameters:

name (str)

Return type:

Array

param_values(*args, **kwargs)

Return module parameter values as a list of jax arrays.

See [parax.Module.named_param_values][].

Return type:

list[Array]

params(*args, **kwargs)

Return module parameters as a list.

See [parax.Module.named_params][].

Return type:

list[Parameter]

path_to_param_name(path)

Convert a PyTree path to a fully-qualified parameter name.

Return type:

str

plot_func(func, freq, *, ax=None, label=None, color=None, **kwargs)

Evaluate and plot an arbitrary function of the current model.

This method evaluates the provided function using the model’s current parameter values and plots the resulting response over frequency.

Parameters:
  • func (Callable[[Model, Frequency], jnp.ndarray]) – Function to evaluate. Must take a Model and a Frequency object and return a jnp.ndarray of shape (n_freqs,).

  • freq (Frequency) – Frequency grid to evaluate over.

  • ax (matplotlib.axes.Axes, optional) – Axes to plot on. If None, the current axes (plt.gca()) are used.

  • label (str, optional) – Label for the plotted line (used in legends).

  • color (str, optional) – Color for the line. If None, uses the matplotlib color cycle.

  • **kwargs (dict) – Additional keyword arguments forwarded to matplotlib.pyplot.plot (e.g., linestyle, linewidth, alpha).

Returns:

The axes containing the plot.

Return type:

matplotlib.axes.Axes

plot_func_samples(func, freq, *, key, num_samples=1000, contours=True, ax=None, label=None, color='C0', alpha=0.1)

Evaluate and plot a function over samples from the parameter distribution.

This method draws samples from the model’s joint parameter distribution, evaluates the provided function for each sample, and plots the resulting responses over frequency.

Parameters:
  • func (Callable[[Model, Frequency], jnp.ndarray]) – Function to evaluate. Must take a Model and a Frequency object and return a jnp.ndarray of shape (n_freqs,).

  • freq (Frequency) – Frequency grid to evaluate over.

  • key (Array) – PRNG key for sampling the distribution.

  • num_samples (int, default=1000) – Number of samples to draw.

  • contours (bool, default=True) – If True, plots the mean response and filled contours corresponding to 1, 2, and 3 standard deviations. If False, plots all individual sample responses as transparent lines.

  • ax (matplotlib.axes.Axes, optional) – Axes to plot on. If None, the current axes (plt.gca()) are used.

  • label (str, optional) – Label for the mean line (used in legends).

  • color (str, default='C0') – Base color for the lines and shaded regions.

  • alpha (float, default=0.1) – Transparency of the individual lines (when contours=False).

Returns:

The axes containing the plot.

Return type:

matplotlib.axes.Axes
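When contours=True, the plot is built from per-frequency statistics of the sampled responses. A sketch of that reduction in plain Python is given below. The helper sigma_bands is hypothetical, and the use of the population standard deviation is an assumption; the library may use a different estimator.

```python
import statistics


def sigma_bands(samples, num_sigmas=(1, 2, 3)):
    """Per-frequency mean and (lower, upper) bands from sampled responses.

    `samples` is a list of responses, each a list with one value per frequency.
    """
    columns = list(zip(*samples))  # regroup values by frequency point
    mean = [statistics.fmean(col) for col in columns]
    std = [statistics.pstdev(col) for col in columns]
    bands = {
        k: ([m - k * s for m, s in zip(mean, std)],
            [m + k * s for m, s in zip(mean, std)])
        for k in num_sigmas
    }
    return mean, bands


# Three hypothetical sampled responses over two frequency points.
samples = [[1.0, 10.0], [2.0, 10.0], [3.0, 10.0]]
mean, bands = sigma_bands(samples)
print(mean)
print(bands[1])
```

Each (lower, upper) pair could then be passed to a fill-between style plotting call, one per sigma level.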

property port_tuples: list[tuple[int, int]]

All (m, n) port index pairs.

Return type:

list[tuple[int, int]]

primary(freq)

Dispatch to the primary function for the given frequency.

Parameters:

freq (Frequency)

Return type:

Array

property primary_function: Callable[[Frequency], Array]

The primary function (s or a) as a callable.

The primary function is the first overridden among PRIMARY_PROPERTIES, unless __call__ is overridden, in which case the primary function of the built model is returned.

Return type:

Callable[[Frequency], jnp.ndarray]

Raises:

NotImplementedError – If no primary property is overridden.

property primary_property: str

The primary property (e.g. "s", "a") as a string.

The primary property is the first overridden among PRIMARY_PROPERTIES, unless __call__ is overridden, in which case the primary property of the built model is returned.

Return type:

str

Raises:

NotImplementedError – If no primary property is overridden.

renumbered(from_ports, to_ports=None, **kwargs)

Return a version of the model with ports renumbered.

See pmrf.models.composite.transformed.Renumbered.

Parameters:
  • from_ports (tuple[int]) – The original port indices that map to to_ports.

  • to_ports (tuple[int]) – The new port indices.

Return type:

Model

sampled(key=None, **kwargs)

Returns a new module with parameters sampled from this module’s distribution.

Return type:

Module

saveable()
Parameters:

self (Self)

Return type:

Self

submodules()

Returns all nested submodules (depth-first), excluding self.

Return type:

list[Module]

terminated(load=None, **kwargs)

Returns a new model that contains this model terminated in another.

See pmrf.models.composite.transformed.Terminated.

Parameters:

load (Model | str, optional) – Load network. Can be ‘short’ or ‘open’ as aliases for SHORT and OPEN. Defaults to a SHORT.

Return type:

Model

to_skrf(frequency, sigma=0.0, **kwargs)

Convert the model at frequencies to an skrf.Network.

The active primary property (self.primary_property) is used.

Parameters:
  • frequency (pmrf.core.frequency | skrf.Frequency) – Frequency grid.

  • sigma (float, default=0.0) – If nonzero, add complex Gaussian noise with stdev sigma to s.

  • **kwargs – Forwarded to skrf.Network constructor.

Return type:

skrf.Network

with_all_params_fixed(**kwargs)

Returns a module with all parameters fixed.

This is an alias for calling [parax.Module.with_free_params][] with fix_others=True and no parameters passed.

See [parax.Module.with_free_params][].

Parameters:

self (Self)

Return type:

Self

with_all_params_free(**kwargs)

Returns a module with all parameters free.

This is an alias for calling [parax.Module.with_free_params][] with all parameters passed.

See [parax.Module.with_free_params][].

Parameters:

self (Self)

Return type:

Self

with_attrs(*args, **kwargs)

Return a copy of the module with one or more attributes replaced.

This is similar to eqx.tree_at but uses string paths.

Usage

# 1. Single attribute update (path, value)
model.with_attrs('a.b.c', 10)

# 2. Batch nested updates via dictionary
model.with_attrs({'a.b.c': 10, 'x.y.z': 20})

# 3. Top-level attributes via keyword arguments
model.with_attrs(name="new_model", _transparent=True)

# 4. Combined dict and kwargs
model.with_attrs({'a.b.c': 10}, name="new_model")

Return type:

Self

classmethod with_defaults(*args, **kwargs)

Return this module type with default initialization arguments.

This method is very useful in utilizing an existing module with default values, without having to create a new module type via inheritance.

Arguments are forwarded as if they were passed to __init__.

Return type:

type[Module]

with_demoted_param_groups()

Recursively demote parameter groups to the deepest possible submodule.

This method identifies parameter groups where every parameter belongs to the same immediate submodule. It moves those groups to the submodule, stripping the prefix. It then recursively calls this method on the submodules to ensure groups continue moving down the hierarchy as far as possible.

Returns:

A new module instance with parameter groups distributed to their lowest relevant submodules.

Return type:

Self

Parameters:

self (Self)

with_fields(*args, **kwargs)

Return a copy of this module with dataclass-style field replacements.

Parameters are forwarded to dataclasses.replace.

Parameters:

self (Self)

Return type:

Self

with_fixed_params(param_filter, free_others=False, **kwargs)

Return a module with specified parameters fixed.

This maps each parameter in the filter, calling [parax.Parameter.as_fixed][] on each.

See [parax.Module.with_mapped_params][].

Return type:

Self

with_fixed_submodules(submodules)

Fix all parameters in the given submodules.

Submodule parameters are obtained using [parax.Module.param_names][] and subsequently fixed using [parax.Module.with_fixed_params][].

Parameters:
  • submodules (Module | Sequence[Module] | str | Sequence[str]) – Submodules whose parameters should be fixed.

  • self (Self)

Return type:

Self

with_free_params(param_filter, *, fix_others=False, **kwargs)

Free the specified parameters.

This maps each parameter in the filter, calling [parax.Parameter.as_free][] on each.

See [parax.Module.with_mapped_params][].

Return type:

Self

with_free_params_only(param_filter, **kwargs)

Returns a module with only the specified parameters freed.

This is an alias for calling [parax.Module.with_free_params][] with fix_others=True.

See [parax.Module.with_free_params][].

Return type:

Self

with_free_submodules(submodules, fix_others=False, include_fixed=True)

Free all parameters in the given submodules.

Submodule parameters are obtained using [parax.Module.param_names][] and subsequently freed using [parax.Module.with_free_params][].

Parameters:
  • submodules (Module | Sequence[Module] | str | Sequence[str]) – Submodules whose parameters should be free.

  • include_fixed (bool, default=True) – Include fixed parameters in the submodule.

  • fix_others (bool, default=False) – Fix all other submodules.

  • self (Self)

Return type:

Self

with_free_submodules_only(*args, include_fixed=False, **kwargs)

Returns a module with only the specified submodules freed.

This is an alias for calling [parax.Module.with_free_submodules][] with fix_others=True and include_fixed=False by default.

See [parax.Module.with_free_params][].

Parameters:

self (Self)

Return type:

Self

with_mapped_distributions(mapper, dist_filter=None, *, map_others=None, param_groups=False)

Return a module with a function applied to its parameter distributions.

This method allows for bulk-updates of distributions, such as widening variances or changing distribution types.

If param_groups is False, the mapping is applied to the distributions of individual parameters (flattened).

If param_groups is True, the mapping is applied to the distributions of [parax.ParameterGroup][] objects. This mode is recursive: it will traverse the module tree and apply the mapping to all explicit parameter groups in all submodules.

Parameters:
  • mapper (Callable[[AbstractDistribution], AbstractDistribution]) – Function that takes a distribution and returns a new one.

  • dist_filter (Callable[[AbstractDistribution], bool] | None, default=None) – A predicate function. If provided, the mapping is only applied to distributions where dist_filter(dist) is True. If None, applies to all.

  • map_others (Callable[[AbstractDistribution], AbstractDistribution] | None, default=None) – An optional map to apply to all distributions NOT in the filter.

  • param_groups (bool, default=False) – If True, map distributions on parameter groups (recursively). If False, map distributions on individual parameters (flat).

  • self (Self)

Returns:

A new module with updated distributions.

Return type:

Self

with_mapped_params(mapper, param_filter=None, *, map_others=None, prefixes=False, include_fixed=False, ignore_unknown=False)

Return a module with specified parameters mapped.

Parameters:
  • mapper (Callable[[Parameter], Parameter]) – The map to apply to each parameter in the filter (or all if no filter).

  • param_filter (str | Sequence[str] | Callable[[str], bool] | None, default=None) – Parameter names to map. If None, applies mapper to all parameters.

  • map_others (Callable[[Parameter], Parameter] | None, default=None) – An optional map to apply to all parameters NOT in the filter.

  • prefixes (bool, default=False) – Specifies that, when a string or list of strings is passed in param_filter, these must be interpreted as parameter prefixes to map and not full path names. Defaults to False.

  • self (Self)

  • include_fixed (bool)

  • ignore_unknown (bool)

Return type:

Self
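The interplay of mapper, param_filter, map_others and prefixes can be sketched on a plain dictionary of values. The function map_params below is a hypothetical stand-in that mirrors the documented arguments. It operates on floats rather than Parameter objects and ignores include_fixed and ignore_unknown.

```python
def map_params(params, mapper, param_filter=None, *, map_others=None, prefixes=False):
    """Apply mapper to selected entries and map_others (if given) to the rest."""

    def selected(name):
        if param_filter is None:
            return True  # no filter: every parameter is mapped
        if callable(param_filter):
            return param_filter(name)
        names = [param_filter] if isinstance(param_filter, str) else list(param_filter)
        if prefixes:
            return any(name.startswith(p) for p in names)
        return name in names

    out = {}
    for name, value in params.items():
        if selected(name):
            out[name] = mapper(value)
        elif map_others is not None:
            out[name] = map_others(value)
        else:
            out[name] = value
    return out


# Hypothetical flat parameter names and values.
params = {"line_length": 0.1, "line_loss": 0.02, "load_R": 50.0}
doubled = map_params(params, lambda v: v * 2, "line_", prefixes=True)
print(doubled)
```

With prefixes=True the string "line_" selects both line parameters, while load_R is left untouched because no map_others was supplied.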

with_name(name)

Return a copy of this module with a different name.

Return type:

Self

with_no_param_groups()

Return a new module with all parameter groups removed recursively.

This clears the _param_groups of the current module and traverses all nested submodules (and sequences of submodules) to remove their parameter groups as well.

Returns:

A new module instance with no parameter groups.

Return type:

Self

Parameters:

self (Self)

with_param_groups(param_groups)

Return a module with parameter groups appended, replacing existing relationships.

This method implements an “atomic replacement” policy. If any parameter in an existing group is claimed by a new group, the entire existing group is removed.

This ensures that groups defining joint distributions are not left in an invalid broken state (e.g. having a dimension removed). Parameters that were in the removed group but not in the new group will revert to being ungrouped (handled by param_groups as singleton groups).

Parameters:
  • param_groups (ParameterGroup or list[ParameterGroup]) – Group(s) to add.

  • self (Self)

Return type:

Self
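The atomic replacement policy can be illustrated with groups represented as tuples of parameter names. This is a sketch of the rule as described, not the library's code: add_groups is a hypothetical helper and real ParameterGroup objects also carry distributions.

```python
def add_groups(existing, new):
    """Append new groups, dropping any existing group that shares a parameter."""
    claimed = set().union(*new) if new else set()
    kept = [g for g in existing if not (set(g) & claimed)]
    return kept + list(new)


existing = [("R1", "R2", "R3"), ("C1", "C2")]
new = [("R1", "L1")]

result = add_groups(existing, new)
print(result)
```

Here the whole ("R1", "R2", "R3") group is removed because the new group claims R1, so R2 and R3 fall back to being ungrouped instead of remaining in a two-parameter remnant of a three-parameter joint distribution.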

with_params(params=None, check_missing=False, check_unknown=True, fix_others=False, include_fixed=False, **param_kwargs)

Return a new module with parameters updated.

This is a multi-purpose function that updates parameters differently based on the types passed.

Parameters:
  • params (dict[str, Parameter] | dict[str, float] | jnp.ndarray | None, optional) – Parameter updates. If an array, all values must be provided (matching flat_params order). You may also pass keyword args.

  • check_missing (bool, default=False) – Require that all module parameters are specified.

  • check_unknown (bool, default=True) – Error if unknown parameter keys are provided.

  • fix_others (bool, default=False) – Fix any parameters not explicitly passed.

  • include_fixed (bool, default=False) – Include fixed parameters when interpreting params mapping.

  • **param_kwargs (dict) – Additional parameter updates by name.

  • self (Self)

Return type:

Self

Raises:

Exception – If shape/order mismatches, unknown/missing names (when checked), or if arrays are found outside of Parameters.

with_submodule_fields(submodule, *args, **kwargs)

Return a copy of this module with dataclass-style field replacements on a nested sub-module.

Parameters are forwarded to dataclasses.replace.

Parameters:
  • submodule (str | Sequence[str]) – The name of the submodule (or sequence of names) to traverse. Can be a single string with a path e.g. ‘submodule1.submodule2’, or a list of submodules e.g. [‘submodule1’, ‘submodule2’].

  • self (Self)

Return type:

Self

with_submodules(*args, **kwargs)

Return a copy of the module with one or more submodules replaced.

This method accepts paths formatted in the exact same way as parameter names (e.g. ‘submodule1_submodule2_submodule3’), respecting transparency and custom names.

Usage

# Single replacement
model.with_submodules('layer1_attention', new_attention_module)

# Batch replacement
model.with_submodules({
    'layer1_attention': new_attn_1,
    'layer2_attention': new_attn_2,
})

Return type:

Self

with_transformed_params(bijector, param_filter=None, **kwargs)

Return a module with a distreqx bijector applied to the specified parameters.

This utilizes the underlying transformed method on the matched Parameters, which updates their physical values, bounds, and distributions simultaneously while preserving the unconstrained latent values.

Parameters:
  • bijector (distreqx.bijectors.AbstractBijector) – The bijector to apply.

  • param_filter (str | Sequence[str] | Callable[[str], bool] | None, default=None) – Parameter names to transform. If None, applies to all parameters.

  • self (Self)

Return type:

Self

with_uniform_distributions(percentage, param_filter=None, *, respect_bounds=False, remove_param_groups=True, zero_values='keep', **kwargs)

Return a module with uniform distributions set centered on current parameter values.

The distributions are defined with bounds calculated as value * (1.0 +/- percentage).

Parameters:
  • percentage (float) – The fractional width of the uniform distribution (e.g. 0.1 = 10%).

  • param_filter (str | Sequence[str] | Callable[[str], bool], default=None) – The parameters to be updated with new uniform distributions. For the default case, all are updated.

  • respect_bounds (bool, default=False) – Whether or not the min and max bounds of the current distributions should be respected. If True, the new bounds will not extend past these bounds.

  • remove_param_groups (bool, default=True) – Whether to remove parameter groups recursively when setting the uniform distributions. Otherwise, the joint distribution of the module may not be the desired uniform distribution.

  • zero_values (str, default='keep') – How to treat zero values. Currently the only option is to keep them and their bounds as is.

Returns:

A new module with updated parameter distributions.

Return type:

Self
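The bound calculation value * (1.0 +/- percentage) is simple enough to check by hand. The sketch below reproduces it in plain Python with a hypothetical helper uniform_bounds. The current_bounds argument mimics respect_bounds=True, and sorting the two endpoints for negative values is an assumption, not documented behaviour.

```python
def uniform_bounds(value, percentage, current_bounds=None):
    """Bounds value * (1 -/+ percentage), optionally clipped to current_bounds."""
    # Sorting keeps lo <= hi when value is negative (assumed behaviour).
    lo, hi = sorted((value * (1.0 - percentage), value * (1.0 + percentage)))
    if current_bounds is not None:  # mimics respect_bounds=True
        lo = max(lo, current_bounds[0])
        hi = min(hi, current_bounds[1])
    return lo, hi


print(uniform_bounds(10.0, 0.1))
print(uniform_bounds(10.0, 0.1, current_bounds=(9.5, 20.0)))
print(uniform_bounds(-10.0, 0.1))
```

A value of exactly zero yields a zero-width interval under this formula, which is presumably why zero values are handled separately through zero_values.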

y(freq)

Admittance (Y) parameter matrix.

If a different parameter type is primary, this converts it to Y.

Parameters:

freq (Frequency) – Frequency grid.

Returns:

Y matrix with shape (nf, n, n).

Return type:

jnp.ndarray

z(freq)

Impedance (Z) parameter matrix.

If a different parameter type is primary, this converts it to Z.

Parameters:

freq (Frequency) – Frequency grid.

Returns:

Z matrix with shape (nf, n, n).

Return type:

jnp.ndarray
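As a worked example of converting between parameter types, the sketch below maps a 2-port Z matrix to S-parameters using S = (Z - z0 I)(Z + z0 I)^-1 at a single frequency point. It assumes a real reference impedance shared by both ports. The function z_to_s is hypothetical and independent of ParamRF's conversion routines.

```python
def z_to_s(z, z0=50.0):
    """Convert a 2x2 Z matrix to S-parameters for a real, common z0."""
    (z11, z12), (z21, z22) = z
    # Numerator Z - z0*I and denominator Z + z0*I.
    n11, n12, n21, n22 = z11 - z0, z12, z21, z22 - z0
    d11, d12, d21, d22 = z11 + z0, z12, z21, z22 + z0
    det = d11 * d22 - d12 * d21
    # Inverse of the 2x2 denominator matrix.
    i11, i12, i21, i22 = d22 / det, -d12 / det, -d21 / det, d11 / det
    # Matrix product numerator @ inverse(denominator).
    return [
        [n11 * i11 + n12 * i21, n11 * i12 + n12 * i22],
        [n21 * i11 + n22 * i21, n21 * i12 + n22 * i22],
    ]


# A shunt impedance Zs has the Z matrix [[Zs, Zs], [Zs, Zs]].
z_shunt = 50.0
s = z_to_s([[z_shunt, z_shunt], [z_shunt, z_shunt]])
print(s)
```

A 50 ohm shunt element in a 50 ohm system gives a reflection of -1/3 and a transmission of 2/3, the dual of the series-element case.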

z0: complex = (50+0j)