ContinuousCallable
- class pmrf.models.adapters.callable.ContinuousCallable(fn: Any, theta: Any, *, name: str | None = None, metadata: Any = None, domain: str = 's', z0: Any = 50.0)
Bases: AbstractSingleDomain
A model that predicts its output at an arbitrary frequency using an arbitrary callable.
This class can be used to wrap external machine learning architectures, such as Equinox modules or other callable PyTrees.
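As an illustration of the kind of architecture this wrapper targets, here is a minimal sketch of a callable PyTree that follows the documented shape contract (scaled frequency of shape (nfreq,) in, array of shape (nfreq, nports, nports) out). It uses plain `jax.tree_util` instead of Equinox so that it has no extra dependency. `TinyFreqModel` and `init_model` are hypothetical names, not part of pmrf, and the output is real-valued for simplicity, whereas a real S-parameter model would typically return complex values.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class TinyFreqModel:
    """Hypothetical callable PyTree: one hidden layer mapping frequency to a matrix."""

    def __init__(self, w1, b1, w2, b2, nports):
        self.w1, self.b1, self.w2, self.b2 = w1, b1, w2, b2
        self.nports = nports  # static metadata, kept out of the PyTree leaves

    def tree_flatten(self):
        # Leaves are the trainable arrays; nports travels as auxiliary data.
        return (self.w1, self.b1, self.w2, self.b2), self.nports

    @classmethod
    def tree_unflatten(cls, nports, leaves):
        return cls(*leaves, nports)

    def __call__(self, f):
        # f: (nfreq,) scaled frequency -> (nfreq, nports, nports)
        h = jnp.tanh(f[:, None] * self.w1 + self.b1)  # (nfreq, hidden)
        out = h @ self.w2 + self.b2                    # (nfreq, nports * nports)
        return out.reshape(f.shape[0], self.nports, self.nports)


def init_model(key, nports=2, hidden=8):
    """Randomly initialise the hypothetical model's weights."""
    k1, k2 = jax.random.split(key)
    w1 = jax.random.normal(k1, (hidden,))
    b1 = jnp.zeros((hidden,))
    w2 = 0.1 * jax.random.normal(k2, (hidden, nports * nports))
    b2 = jnp.zeros((nports * nports,))
    return TinyFreqModel(w1, b1, w2, b2, nports)


model = init_model(jax.random.PRNGKey(0))
f = jnp.linspace(0.0, 1.0, 5)
s = model(f)
print(s.shape)  # (5, 2, 2)
```

Because the weights are PyTree leaves, the object can pass through `jax.jit` and `jax.grad` like an Equinox module; how pmrf freezes such an object internally is not covered by this sketch.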
- Parameters:
fn (Callable[[jnp.ndarray], jnp.ndarray] | Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]) – The underlying callable model, which predicts the response as a function of scaled frequency. May be either a function or a callable PyTree (e.g. equinox.Module), which will be frozen. Must accept an array of shape (nfreq,) or (nfreq, nparams), depending on whether theta is None, and return an array of shape (nfreq, nports, nports).
theta (Param) – Parameters of length nparams to pass to fn. Can be None for models that contain their own parax.Parameter objects. All parameters, including fixed parameters, are passed.
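For the simplest case, theta is None and fn receives only the scaled-frequency array of shape (nfreq,). A minimal sketch of a plain function meeting that contract, assuming a one-port device: `one_port_fn` and its series R-L load are hypothetical illustrations, not part of pmrf, and because the exact frequency scaling applied by the class is not stated here, the input is treated as a dimensionless value.

```python
import jax.numpy as jnp


def one_port_fn(f):
    """Hypothetical one-port model: a series R-L load normalised to z0.

    f: (nfreq,) scaled frequency. Returns (nfreq, 1, 1) complex reflection coefficients.
    """
    z = 1.0 + 1j * f               # normalised impedance Z / z0 (R = z0, reactance equal to f)
    gamma = (z - 1.0) / (z + 1.0)  # reflection coefficient referenced to z0
    return gamma[:, None, None]    # add the (nports, nports) = (1, 1) axes


f = jnp.array([0.0, 1.0])
s11 = one_port_fn(f)
print(s11.shape)  # (2, 1, 1)
```

Following the documented signature, such a function would be passed as `ContinuousCallable(one_port_fn, None)`; at f = 0 the load is matched (reflection 0), and at f = 1 the reflection is 0.2 + 0.4j.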
- matrix(freq: Frequency) → Array
Compute the intrinsic matrix for the domain.
- Parameters:
freq (Frequency) – The frequency grid to evaluate on.
- Returns:
The intrinsic domain matrix.
- Return type:
jax.numpy.ndarray
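To make the evaluation step concrete, the sketch below shows one plausible way a method like matrix could turn a frequency grid into the intrinsic matrix: scale the grid, call fn, and check the (nfreq, nports, nports) shape. Everything here is hypothetical: `evaluate_matrix` is not a pmrf function, the min-max scaling to [0, 1] is an assumed stand-in for the class's undocumented scaling, and the two-argument branch assumes fn is called as `fn(f, theta)` as the type hint suggests.

```python
import jax.numpy as jnp


def evaluate_matrix(fn, freq_hz, theta=None):
    """Hypothetical stand-in for matrix(): scale the grid, call fn, validate the shape.

    ASSUMPTION: min-max scaling to [0, 1]; pmrf's actual scaling may differ.
    """
    freq_hz = jnp.asarray(freq_hz)
    span = freq_hz[-1] - freq_hz[0]
    f_scaled = (freq_hz - freq_hz[0]) / jnp.where(span == 0, 1.0, span)
    # ASSUMPTION: the two-argument form receives theta as a separate argument.
    out = fn(f_scaled) if theta is None else fn(f_scaled, theta)
    if out.ndim != 3 or out.shape[0] != freq_hz.shape[0] or out.shape[1] != out.shape[2]:
        raise ValueError(f"expected (nfreq, nports, nports), got {out.shape}")
    return out


grid = jnp.linspace(1e9, 2e9, 11)  # hypothetical 1 GHz to 2 GHz grid
s = evaluate_matrix(lambda f: jnp.zeros((f.shape[0], 2, 2)), grid)
print(s.shape)  # (11, 2, 2)
```

Validating the shape at the boundary is a cheap way to catch a callable that returns, for example, (nfreq, nports) instead of a square matrix per frequency point.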
- fn: Callable[[Array], Array] | Callable[[Array, Array], Array]
The underlying callable model.
- theta: AbstractVariable | Inexact[jaxlib._jax.Array, '...']
Parameters to pass to fn.