# Source code for pmrf.core.model

"""
The main model class.
"""

from typing import Callable, TYPE_CHECKING, Any

import jax
import jax.numpy as jnp
import equinox as eqx
import parax as prx

from pmrf.core.frequency import Frequency
from pmrf.rf import a2s, s2a, s2z, z2s, s2y, y2s
from pmrf.math import CONVERSION_LOOKUP
from pmrf.constants import PRIMARY_PROPERTIES
from pmrf.utils import is_overridden

if TYPE_CHECKING:
    import skrf

# Warning text about models whose characteristic impedance (z0) differs from
# 50 ohm. The string starts with a newline and every line ends with one.
Z0_WARNING = (
    "\n"
    "WARNING: You have created a model with characteristic impedance other than 50 ohm.\n"
    "Working with multiple models in ParamRF with differing characteristic impedances\n"
    "is not yet officially supported and you may encounter subtle bugs. For now, it is\n"
    "recommended to keep the default z0 and convert your results at the end.\n"
)

class Model(prx.Module):
    """
    Base class for RF models.

    This base class is used to represent any computable RF network, referred to in
    **ParamRF** as a "Model". This class can be overridden for defining complex models,
    or can be utilized indirectly by combining models already provided in :mod:`pmrf.models`.

    This class is abstract and should not be instantiated directly. Derive from :class:`Model`
    and override one of the primary property functions (e.g. :meth:`.__call__`, :meth:`.s`, :meth:`.a`).

    The model is a Parax/Equinox ``Module`` (immutable, dataclass-like) and is
    treated as a JAX PyTree. Parameters are declared using standard dataclass
    field syntax using `parax.Parameter`.

    Usage
    -----
    - Define new models by sub-classing the model and adding custom parameters and/or sub-models.
    - Construct models by passing parameters and/or submodels to the initializer (like a dataclass).
    - Use "past tense" functions to modify the model in conjunction with another model or data e.g. :meth:`.terminated`, :meth:`.flipped`.
    - Retrieve parameter information via `Parax` methods such as :meth:`.named_params`, :meth:`.param_names`, :meth:`.flat_params`, etc.
    - Use **Parax** ``with_xxx`` functions to modify fields, models and parameters within the model e.g. :meth:`.with_params`, :meth:`.with_fields`.

    Methods & Properties Summary
    ----------------------------

    **Core API**

    ================================= ====================================================================
    Method                            Description
    ================================= ====================================================================
    :meth:`__call__`                  Build the model. Should be overridden by sub-classes.
    :meth:`s`                         Scattering (S) parameter matrix.
    :meth:`a`                         ABCD parameter matrix.
    :meth:`z`                         Impedance (Z) parameter matrix.
    :meth:`y`                         Admittance (Y) parameter matrix.
    :meth:`primary`                   Dispatch to the primary function for the given frequency.
    :attr:`primary_function`          The primary function (``s`` or ``a``) as a callable.
    :attr:`primary_property`          The primary property (e.g. ``"s"``, ``"a"``) as a string.
    :attr:`number_of_ports`           Number of ports.
    :attr:`nports`                    Alias of :attr:`number_of_ports`.
    :attr:`port_tuples`               All (m, n) port index pairs.
    ================================= ====================================================================

    **Model Manipulation**

    ================================= ====================================================================
    Method                            Description
    ================================= ====================================================================
    :meth:`flipped`                   Return a version of the model with ports flipped.
    :meth:`renumbered`                Return a version of the model with ports renumbered.
    :meth:`terminated`                Return a new model terminated by another (e.g. load).
    ================================= ====================================================================

    **Plotting, File, & Conversion Utilities**

    ================================= ====================================================================
    Method                            Description
    ================================= ====================================================================
    :meth:`plot_func`                 Evaluate and plot an arbitrary function of the model.
    :meth:`plot_func_samples`         Evaluate and plot a function over parameter samples.
    :meth:`to_skrf`                   Convert the model at frequencies to an :class:`skrf.Network`.
    :meth:`export_touchstone`         Export the model response to a Touchstone file.
    ================================= ====================================================================

    Examples
    --------
    A ``PiCLC`` network ("foundational" model with fixed parameters and equations):

    .. code-block:: python

        import jax.numpy as jnp
        import parax as prx
        import pmrf as prf

        class PiCLC(prf.Model):
            C1: prx.Parameter = 1.0e-12
            L:  prx.Parameter = 1.0e-9
            C2: prx.Parameter = 1.0e-12

            def a(self, freq: prf.Frequency) -> jnp.ndarray:
                w = freq.w
                Y1, Y2, Y3 = (1j * w * self.C1), (1j * w * self.C2), 1 / (1j * w * self.L)
                return jnp.array([
                    [1 + Y2 / Y3,        1 / Y3],
                    [Y1 + Y2 + Y1*Y2/Y3, 1 + Y1 / Y3],
                ]).transpose(2, 0, 1)

    An ``RLC`` network ("circuit" model with free parameters built using cascading):

    .. code-block:: python

        import pmrf as prf
        from pmrf.core import Resistor, Capacitor, Inductor, Cascade
        from parax.parameters import Uniform

        class RLC(prf.Model):
            res: Resistor = Resistor(Uniform(9.0, 11.0))
            ind: Inductor = Inductor(Uniform(0.0, 10.0, scale=1e-9))
            cap: Capacitor = Capacitor(Uniform(0.0, 10.0, scale=1e-12))

            def __call__(self) -> prf.Model:
                return self.res ** self.ind ** self.cap.terminated()

    """
    # Reference (characteristic) impedance used by all parameter conversions in
    # this class; a static (non-traced), keyword-only field.
    z0: complex = prx.field(default=50.0+0j, kw_only=True, static=True)

    def __init_subclass__(cls, **kwargs):
        """Attach auto-generated conversion methods to every subclass.

        For each property in :data:`PRIMARY_PROPERTIES` and each suffix in
        :data:`CONVERSION_LOOKUP`, two methods are added unless ``cls``
        (including anything it inherits) already defines them:

        - ``<prop>_<suffix>`` (e.g. ``s_mag``), which applies the lookup's
          conversion function to ``self.<prop>(*args, **kwargs)``;
        - ``<prop>_mn_<suffix>`` (e.g. ``s_mn_mag``), which applies it to
          ``self.<prop>_mn(*args, **kwargs)``.

        Generated methods carry the marker attribute ``_pmrf_auto = True``.
        """
        super().__init_subclass__(**kwargs)

        def make_dynamic_method(prop_name, func, method_name):
            # Factory so each generated method captures its own prop_name/func
            # (avoids the late-binding-closure pitfall inside the loops below).
            def dynamic_method(self, *args, **kwargs):
                matrix = getattr(self, prop_name)(*args, **kwargs)
                return func(matrix)
            # Give the generated function a real name for introspection/tracebacks.
            dynamic_method.__name__ = method_name
            dynamic_method.__qualname__ = f"{cls.__qualname__}.{method_name}"
            dynamic_method._pmrf_auto = True
            return dynamic_method

        for prop in PRIMARY_PROPERTIES:
            for suffix, lookup in CONVERSION_LOOKUP.items():
                # Element [1] of each lookup entry is the conversion function.
                func = lookup[1]
                for source_name, method_name in (
                    (prop, f"{prop}_{suffix}"),             # base, e.g. s_mag
                    (f"{prop}_mn", f"{prop}_mn_{suffix}"),  # indexed, e.g. s_mn_mag
                ):
                    # hasattr also sees inherited attributes, so user overrides
                    # (and methods generated for a base class) are preserved.
                    if not hasattr(cls, method_name):
                        setattr(cls, method_name, make_dynamic_method(source_name, func, method_name))

    # ---- Defaults / Primary ---------------------------------------------------

    @property
    def primary_function(self) -> Callable[[Frequency], jnp.ndarray]:
        """The primary function (``s`` or ``a``) as a callable.

        The primary function is the first overridden among
        :data:`PRIMARY_PROPERTIES`, unless ``__call__`` is overridden,
        in which case the primary function of the built model is returned.

        Returns
        -------
        Callable[[Frequency], jnp.ndarray]

        Raises
        ------
        NotImplementedError
            If no primary property is overridden.
        """
        return getattr(self, self.primary_property)

    @property
    def primary_property(self) -> str:
        """The primary property (e.g. ``"s"``, ``"a"``) as a string.

        The primary property is the first overridden among
        :data:`PRIMARY_PROPERTIES`, unless ``__call__`` is overridden,
        in which case the primary property of the built model is returned.

        Returns
        -------
        str

        Raises
        ------
        NotImplementedError
            If no primary property is overridden.
        """
        if is_overridden(type(self), Model, '__call__'):
            return self().primary_property

        prioritized = ()  # reserved for future expansion; checked first
        unprioritized = tuple(p for p in PRIMARY_PROPERTIES if p not in prioritized)

        for prop in (*prioritized, *unprioritized):
            if is_overridden(type(self), Model, prop):
                return prop
        raise NotImplementedError(f"No primary properties in {PRIMARY_PROPERTIES} are overridden, which are the only ones supported currently")

    # ---- Introspection properties --------------------------------------------------------

    @property
    def number_of_ports(self) -> int:
        """Number of ports.

        Determined from the shape of :meth:`s` evaluated abstractly (via
        ``jax.eval_shape``) on a small placeholder frequency grid.

        Returns
        -------
        int
        """
        freq = Frequency(1, 2, 2)
        s_shape = jax.eval_shape(lambda: self.s(freq)).shape
        return s_shape[1]

    @property
    def nports(self) -> int:
        """Alias of :attr:`number_of_ports`."""
        return self.number_of_ports

    @property
    def port_tuples(self) -> list[tuple[int, int]]:
        """All (m, n) port index pairs.

        Pairs are ordered with the first index varying fastest, i.e.
        ``(0, 0), (1, 0), ..., (0, 1), (1, 1), ...``.

        Returns
        -------
        list[tuple[int, int]]
        """
        # Evaluate once: every nports access traces ``self.s`` via eval_shape.
        n = self.nports
        return [(m, k) for k in range(n) for m in range(n)]

    # ---- Core API -------------------------------------------------------------

    def __call__(self) -> 'Model':
        """Build the model.

        This function should be over-ridden by sub-classes.
        It is useful in defining complex models that are built
        using several sub-models (as opposed to equation-based models).

        Returns
        -------
        Model

        Raises
        ------
        NotImplementedError
            In the base class; override in derived classes to build
            a compositional representation.
        """
        raise NotImplementedError

    @eqx.filter_jit
    def primary(self, freq: Frequency) -> jnp.ndarray:
        """Dispatch to the primary function for the given frequency."""
        primary_function = self.primary_function
        return primary_function(freq)

    def _primary_as_s(self, primary_prop: str, val: jnp.ndarray, target: str) -> jnp.ndarray:
        """Convert a primary-property matrix to S-parameters (the conversion "hub").

        Parameters
        ----------
        primary_prop : str
            Name of the property ``val`` was computed from (``'s'``, ``'a'``, ``'z'`` or ``'y'``).
        val : jnp.ndarray
            Matrix returned by :meth:`primary`.
        target : str
            Property ultimately requested by the caller; used only in the error message.

        Returns
        -------
        jnp.ndarray
            S-parameter matrix referenced to ``self.z0``.

        Raises
        ------
        NotImplementedError
            If there is no conversion from ``primary_prop`` to S.
        """
        if primary_prop == 's':
            return val
        if primary_prop == 'a':
            return a2s(val, self.z0)
        if primary_prop == 'z':
            return z2s(val, self.z0)
        if primary_prop == 'y':
            return y2s(val, self.z0)
        raise NotImplementedError(f"Conversion from '{primary_prop}' to '{target}' is not implemented.")

    @eqx.filter_jit
    def s(self, freq: Frequency) -> jnp.ndarray:
        """Scattering parameter matrix.

        If a different parameter type (a, z, y) is primary, this converts it to S.

        Note that, in ParamRF, the **power wave** definition of S-parameters should be used.
        If you have a formulation in terms of another definition (such as traveling waves),
        simply use :meth:`pmrf.rf.s2s` (or :meth:`pmrf.rf.renormalize_s` if you need to change impedance too).

        Parameters
        ----------
        freq : Frequency
            Frequency grid.

        Returns
        -------
        jnp.ndarray
            S-parameter matrix with shape ``(nf, n, n)``.

        Raises
        ------
        NotImplementedError
            If the primary property cannot be converted to S.
        """
        if is_overridden(type(self), Model, '__call__'):
            return self().s(freq)

        primary_prop = self.primary_property
        val = self.primary(freq)
        return self._primary_as_s(primary_prop, val, 's')

    @eqx.filter_jit
    def a(self, freq: Frequency) -> jnp.ndarray:
        """ABCD parameter matrix.

        If a different parameter type is primary, this converts it to A (via S).

        Parameters
        ----------
        freq : Frequency
            Frequency grid.

        Returns
        -------
        jnp.ndarray
            ABCD matrix with shape ``(nf, 2, 2)``.

        Raises
        ------
        NotImplementedError
            If the primary property cannot be converted to A.
        """
        if is_overridden(type(self), Model, '__call__'):
            return self().a(freq)

        primary_prop = self.primary_property
        val = self.primary(freq)
        if primary_prop == 'a':
            return val
        return s2a(self._primary_as_s(primary_prop, val, 'a'), self.z0)

    @eqx.filter_jit
    def z(self, freq: Frequency) -> jnp.ndarray:
        """Impedance (Z) parameter matrix.

        If a different parameter type is primary, this converts it to Z (via S).

        Parameters
        ----------
        freq : Frequency
            Frequency grid.

        Returns
        -------
        jnp.ndarray
            Z matrix with shape ``(nf, n, n)``.

        Raises
        ------
        NotImplementedError
            If the primary property cannot be converted to Z.
        """
        if is_overridden(type(self), Model, '__call__'):
            return self().z(freq)

        primary_prop = self.primary_property
        val = self.primary(freq)
        if primary_prop == 'z':
            return val
        return s2z(self._primary_as_s(primary_prop, val, 'z'), self.z0)

    @eqx.filter_jit
    def y(self, freq: Frequency) -> jnp.ndarray:
        """Admittance (Y) parameter matrix.

        If a different parameter type is primary, this converts it to Y (via S).

        Parameters
        ----------
        freq : Frequency
            Frequency grid.

        Returns
        -------
        jnp.ndarray
            Y matrix with shape ``(nf, n, n)``.

        Raises
        ------
        NotImplementedError
            If the primary property cannot be converted to Y.
        """
        if is_overridden(type(self), Model, '__call__'):
            return self().y(freq)

        primary_prop = self.primary_property
        val = self.primary(freq)
        if primary_prop == 'y':
            return val
        return s2y(self._primary_as_s(primary_prop, val, 'y'), self.z0)

    # ---- Magic methods and copying --------------------------------------------------

    def __getattr__(self, name: str):
        """Dynamic dispatch for scikit-rf plotting methods.

        Missing attributes named ``plot_*`` (e.g. ``model.plot_s_db(freq)``)
        resolve to a function that converts the model with :meth:`to_skrf` at
        ``freq`` and calls the same-named method on the resulting network,
        i.e. ``model.to_skrf(freq).plot_s_db(*args, **kwargs)``.

        Raises
        ------
        AttributeError
            For any other missing attribute, or (when the returned function is
            called) if the generated Network has no attribute ``name``.
        """
        if not name.startswith('plot_'):
            # Standard fallback if the attribute isn't a plot command
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

        def plotter(freq: Frequency, *args, **kwargs):
            # 1. Convert to scikit-rf Network at the specified frequency
            ntwk = self.to_skrf(freq)

            # 2. Check if the generated Network actually supports this plot type
            if not hasattr(ntwk, name):
                raise AttributeError(f"scikit-rf Network object has no attribute '{name}'")

            # 3. Call the scikit-rf plot method with remaining args (e.g. labels, colors)
            return getattr(ntwk, name)(*args, **kwargs)

        plotter.__name__ = name
        return plotter

    def __pow__(self, other: 'Model') -> 'Model':
        """Cascade/termination composition operator ``**``.

        A 1-port ``other`` terminates this model; otherwise the two are cascaded.
        """
        if other.nports == 1:
            from pmrf.models import Terminated
            return Terminated(self, other)
        from pmrf.models import Cascade
        return Cascade([self, other])

    def flipped(self, **kwargs) -> 'Model':
        """Return a version of the model with ports flipped.

        See :class:`pmrf.models.composite.transformed.Flipped`. Flipping an
        already-flipped model returns the original wrapped model.

        Parameters
        ----------
        **kwargs
            Forwarded to :class:`Flipped`.

        Returns
        -------
        Model
        """
        from pmrf.models import Flipped
        if isinstance(self, Flipped):
            return self.model
        return Flipped(self, **kwargs)

    def renumbered(self, from_ports: tuple[int], to_ports: tuple[int] | None = None, **kwargs) -> 'Model':
        """Return a version of the model with ports renumbered.

        See :class:`pmrf.models.composite.transformed.Renumbered`.

        Parameters
        ----------
        from_ports : tuple[int]
            The original port indices that map to `to_ports`.
        to_ports : tuple[int], optional
            The new port indices.
        **kwargs
            Forwarded to :class:`Renumbered`.

        Returns
        -------
        Model
        """
        from pmrf.models import Renumbered
        return Renumbered(self, from_ports, to_ports, **kwargs)

    def terminated(self, load: 'Model | str | None' = None, **kwargs) -> 'Model':
        """Returns a new model that contains this model terminated in another.

        See :class:`pmrf.models.composite.transformed.Terminated`.

        Parameters
        ----------
        load : Model | str, optional
            Load network. Can be 'short' or 'open' as aliases for SHORT and OPEN.
            Defaults to a SHORT.
        **kwargs
            Forwarded to :class:`Terminated`.

        Returns
        -------
        Model

        Raises
        ------
        ValueError
            If ``load`` is a string other than 'short' or 'open'.
        """
        from pmrf.models import SHORT, OPEN, Terminated
        if isinstance(load, str):
            if load == 'short':
                load = SHORT
            elif load == 'open':
                load = OPEN
            else:
                raise ValueError(f"Unknown load alias {load} received in 'Model.terminated()'")
        # Explicit None check: a Model instance must never be swapped for SHORT
        # just because it happens to evaluate as falsy.
        if load is None:
            load = SHORT
        return Terminated(self, load, **kwargs)

    # ---- Plotting --------------------------------------------------

    def plot_func(
        self,
        func: Callable[['Model', Frequency], jnp.ndarray],
        freq: Frequency,
        *,
        ax = None,
        label: str | None = None,
        color: str | None = None,
        **kwargs
    ):
        """Evaluate and plot an arbitrary function of the current model.

        This method evaluates the provided function using the model's current
        parameter values and plots the resulting response over frequency.

        Parameters
        ----------
        func : Callable[[Model, Frequency], jnp.ndarray]
            Function to evaluate. Must take a Model and a Frequency object and return
            a jnp.ndarray of shape (n_freqs,).
        freq : Frequency
            Frequency grid to evaluate over.
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, the current axes (`plt.gca()`) are used.
        label : str, optional
            Label for the plotted line (used in legends).
        color : str, optional
            Color for the line. If None, uses the matplotlib color cycle.
        **kwargs : dict
            Additional keyword arguments forwarded to `matplotlib.pyplot.plot`
            (e.g., `linestyle`, `linewidth`, `alpha`).

        Returns
        -------
        matplotlib.axes.Axes
            The axes containing the plot.
        """
        import matplotlib.pyplot as plt
        import numpy as np

        if ax is None:
            ax = plt.gca()

        # 1. Evaluate the function on the current model
        y_val = np.asarray(func(self, freq))

        # Extract the x-axis automatically from the frequency object
        x_axis = np.asarray(freq.f_scaled)

        # 2. Assemble kwargs so 'color'/'label' are never passed twice
        plot_kwargs = kwargs.copy()
        if label is not None:
            plot_kwargs['label'] = label
        if color is not None:
            plot_kwargs['color'] = color

        ax.plot(x_axis, y_val, **plot_kwargs)
        return ax

    def plot_func_samples(
        self,
        func: Callable[['Model', Frequency], jnp.ndarray],
        freq: Frequency,
        *,
        key: jax.Array,
        num_samples: int = 1000,
        contours: bool = True,
        ax = None,
        label: str | None = None,
        color: str = 'C0',
        alpha: float = 0.1,
    ):
        """Evaluate and plot a function over samples from the parameter distribution.

        Samples are obtained via ``self.func_samples(func, freq, key=..., num_samples=...)``,
        and the resulting responses are plotted over frequency.

        Parameters
        ----------
        func : Callable[[Model, Frequency], jnp.ndarray]
            Function to evaluate. Must take a Model and a Frequency object and return
            a jnp.ndarray of shape (n_freqs,).
        freq : Frequency
            Frequency grid to evaluate over.
        key : jax.Array
            PRNG key for sampling the distribution.
        num_samples : int, default=1000
            Number of samples to draw.
        contours : bool, default=True
            If True, plots the mean response and filled contours corresponding
            to 1, 2, and 3 standard deviations. If False, plots all individual
            sample responses as transparent lines.
        ax : matplotlib.axes.Axes, optional
            Axes to plot on. If None, the current axes (`plt.gca()`) are used.
        label : str, optional
            Label for the mean line (used in legends).
        color : str, default='C0'
            Base color for the lines and shaded regions.
        alpha : float, default=0.1
            Transparency of the individual lines (when `contours=False`).

        Returns
        -------
        matplotlib.axes.Axes
            The axes containing the plot.
        """
        import matplotlib.pyplot as plt
        import numpy as np

        if ax is None:
            ax = plt.gca()

        # 1. Evaluate the ensemble
        y_samples = np.asarray(self.func_samples(func, freq, key=key, num_samples=num_samples))

        # Extract the x-axis automatically from the frequency object
        x_axis = np.asarray(freq.f_scaled)

        # 2. Central tendency across samples (axis 0 indexes samples)
        y_mean = np.mean(y_samples, axis=0)

        # 3. Plotting logic
        if not contours:
            # Transpose so matplotlib treats each sample as its own line
            ax.plot(x_axis, y_samples.T, color=color, alpha=alpha)
            # Mean drawn as a solid line on top
            ax.plot(x_axis, y_mean, color=color, label=label, linewidth=2)
        else:
            ax.plot(x_axis, y_mean, color=color, label=label, linewidth=2)

            # Filled bands for 1, 2 and 3 standard deviations, fading outwards
            y_std = np.std(y_samples, axis=0)
            for n_sigma, sig_alpha in zip([1, 2, 3], [0.3, 0.2, 0.1]):
                ax.fill_between(
                    x_axis,
                    y_mean - n_sigma * y_std,
                    y_mean + n_sigma * y_std,
                    color=color,
                    alpha=sig_alpha,
                    linewidth=0
                )

        return ax

    # ---- File and conversion utilities --------------------------------------------------

    def to_skrf(self, frequency: Frequency | Any, sigma=0.0, **kwargs) -> 'skrf.Network':
        """Convert the model at frequencies to an :class:`skrf.Network`.

        The network is built from the model's S-parameters (:meth:`s`), i.e.
        the primary property converted with this model's ``z0``.

        Parameters
        ----------
        frequency : pmrf.core.frequency | skrf.Frequency
            Frequency grid.
        sigma : float, default=0.0
            If nonzero, add complex Gaussian noise with stdev ``sigma`` (drawn
            from NumPy's global random state) to ``s``.
        **kwargs
            Forwarded to :class:`skrf.Network` constructor. The ``s``,
            ``frequency`` and ``z0`` entries are always overwritten; ``name``
            defaults to ``self.name`` when present.

        Returns
        -------
        skrf.Network
        """
        import skrf
        import numpy as np

        if isinstance(frequency, Frequency):
            model_freq = frequency
            measured_freq = frequency.to_skrf()
        else:
            model_freq = Frequency.from_skrf(frequency)
            measured_freq = frequency

        # Resolve the name lazily so a user-supplied name never requires
        # ``self.name`` to exist.
        if 'name' not in kwargs:
            kwargs['name'] = getattr(self, 'name', None)

        # Always hand skrf S-parameters: skrf's constructor applies primary-
        # property kwargs (a, z, y, ...) before ``z0``, so letting skrf convert a
        # non-S primary could use its default 50 ohm instead of ``self.z0``.
        kwargs.update({
            's': np.array(self.s(model_freq)),
            'frequency': measured_freq,
            'z0': self.z0,
        })
        ntwk = skrf.Network(**kwargs)

        if sigma != 0.0:
            noise_shape = ntwk.s.shape
            ntwk.s += (np.random.normal(0, sigma, noise_shape) + 1j * np.random.normal(0, sigma, noise_shape))
        return ntwk

    def export_touchstone(self, filename: str, frequency: Frequency | Any, sigma: float = 0.0, **skrf_kwargs):
        """Export the model response to a Touchstone file via scikit-rf.

        Parameters
        ----------
        filename : str
            Output file name.
        frequency : Frequency | skrf.Frequency
            Frequency grid to evaluate the model on.
        sigma : float, default=0.0
            Additive complex noise std for S-parameters.
        **skrf_kwargs
            Forwarded to :meth:`skrf.Network.write_touchstone`.

        Returns
        -------
        Any
            Return value of ``Network.write_touchstone``.

        Raises
        ------
        TypeError
            If ``filename`` is not a string.
        """
        if not isinstance(filename, str):
            raise TypeError('Filename must be a string')
        ntwk = self.to_skrf(frequency, sigma=sigma)
        return ntwk.write_touchstone(filename, **skrf_kwargs)