ContinuousCallable

class pmrf.models.adapters.callable.ContinuousCallable(fn: Any, theta: Any, *, z0: complex = 50 + 0j, name: str | None = None, metadata: Any = None, kind: str = 's')

Bases: AbstractSingleProperty

A model that predicts its output at an arbitrary frequency using an arbitrary callable.

This class can be used to wrap external machine learning architectures (e.g. Equinox, Parax or other models).

Parameters:
  • fn (Callable[[jnp.ndarray], jnp.ndarray] | Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]) – The underlying callable model, which predicts the response as a function of scaled frequency. May be either a plain function or a callable PyTree (e.g. equinox.Module), which will be frozen. Must accept an array of shape (nfreq,) or (nfreq, nparams), depending on whether theta is None, and return an array of shape (nfreq, nports, nports).

  • theta (Param) – Parameters of length nparams to pass to fn. Can be None for models that contain their own parax.Parameter objects. All parameters, including fixed ones, are passed to fn.
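As a rough sketch of the fn contract in the case where theta is None, the hypothetical TinyMLP class below maps an array of scaled frequencies of shape (nfreq,) to a one-port response of shape (nfreq, 1, 1). It stands in for an external model such as an equinox.Module. Assumptions: it is written with plain NumPy, its weights are ordinary arrays rather than parax.Parameter objects, and the 0 to 1 frequency grid is only an assumed scaling, not the one pmrf uses.

```python
import numpy as np


class TinyMLP:
    """Hypothetical two-layer perceptron: scaled frequency -> 1-port S-matrix.

    Illustrative stand-in for an external model (e.g. an equinox.Module);
    not part of pmrf.
    """

    def __init__(self, hidden: int = 8, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Hidden layer: one input feature (the scaled frequency).
        self.w1 = rng.normal(size=(1, hidden))
        self.b1 = np.zeros(hidden)
        # Output layer: two outputs, the real and imaginary parts of S11.
        self.w2 = rng.normal(size=(hidden, 2)) / hidden
        self.b2 = np.zeros(2)

    def __call__(self, f: np.ndarray) -> np.ndarray:
        x = f[:, None]                       # (nfreq, 1)
        h = np.tanh(x @ self.w1 + self.b1)   # (nfreq, hidden)
        out = h @ self.w2 + self.b2          # (nfreq, 2)
        s11 = out[:, 0] + 1j * out[:, 1]     # (nfreq,) complex
        return s11[:, None, None]            # (nfreq, nports, nports) with nports = 1


f_scaled = np.linspace(0.0, 1.0, 5)  # assumed scaled-frequency grid
s = TinyMLP()(f_scaled)
print(s.shape)  # (5, 1, 1)
```

Since pmrf works with jnp arrays, a real fn would presumably need to be written with jax.numpy so that it can be traced; every operation used above (matrix products, np.tanh, indexing with None, complex arithmetic) has a direct jax.numpy equivalent.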

primary_matrix(freq: Frequency) → Array

The primary matrix (e.g. s or a) as a function of frequency.

The primary matrix is the matrix returned by pmrf.Model.primary_property. This property is either overridden by subclasses or defaults to the first of pmrf.Model.s(), pmrf.Model.a(), pmrf.Model.y() and pmrf.Model.z() (checked in that order) that the subclass overrides directly. If pmrf.Model.build() is overridden, the primary matrix of the built model is returned instead.

This method can itself be overridden in order to implement one of the matrices dynamically, rather than overriding that matrix explicitly.
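The default resolution rule described above (the first directly overridden method among s, a, y and z) can be sketched in plain Python. HypotheticalModel, AdmittanceAndABCD and primary_property_name are illustrative names, not part of pmrf, and the sketch deliberately ignores the primary_property and build() overrides.

```python
class HypotheticalModel:
    """Stand-in base class with the four network-matrix methods (not pmrf.Model)."""

    def s(self, freq):
        raise NotImplementedError

    def a(self, freq):
        raise NotImplementedError

    def y(self, freq):
        raise NotImplementedError

    def z(self, freq):
        raise NotImplementedError


def primary_property_name(cls, base=HypotheticalModel, order=("s", "a", "y", "z")):
    """Return the first method name in `order` that `cls` overrides relative to `base`."""
    for name in order:
        # Accessing a method on a class returns the plain function object, so an
        # identity check tells us whether the subclass replaced it.
        if getattr(cls, name) is not getattr(base, name):
            return name
    raise NotImplementedError("No primary property is overridden.")


class AdmittanceAndABCD(HypotheticalModel):
    """Overrides y and a; a wins because it comes earlier in the order."""

    def y(self, freq):
        ...

    def a(self, freq):
        ...


print(primary_property_name(AdmittanceAndABCD))  # a
```

Comparing the function objects found on the class with those on the base class detects an override without calling any of the methods, which matters when the unimplemented ones raise NotImplementedError.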

Return type:

jnp.ndarray

Raises:

NotImplementedError – If no primary property is overridden.

fn: Callable[[Array], Array] | Callable[[Array, Array], Array]

The underlying callable model.

theta: AbstractVariable | Inexact[jaxlib._jax.Array, '...']

Parameters to pass to fn.