"""
Non-uniform transmission lines.
"""
from typing import Callable, Any, Dict
import jax
import jax.numpy as jnp
import equinox as eqx
from parax import Parameter
from pmrf.core import Frequency, Model
from pmrf.models.components.lines.uniform import RLGCLine
from pmrf.rf import cascade_s
[docs]
class ProfiledLine(Model, transparent=True):
    r"""
    A non-uniform transmission line defined by an arbitrary profile.

    This model wraps any RLGC transmission line model and allows its parameters
    to vary as a function of length. For example, the line's characteristic
    impedance can be varied exponentially for impedance-matching purposes.

    Any line parameter can be uniform across the length or follow a spatial
    profile defined by a user-provided function (e.g., splines). Both the uniform
    parameters and the coefficients of the profile functions are stored as model
    fields (``uniform_params`` and ``profile_params``), making the profile
    parameters compatible with fitting and sampling.

    Each line parameter passed as a keyword argument may take one of three forms:

    - a plain value: the parameter is uniform along the line;
    - a ``dict`` of keyword arguments: the parameter follows ``profile_fn``,
      evaluated as ``profile_fn(t, **kwargs)`` where ``t`` is the normalized
      position along the line in ``[0, 1]``;
    - a ``(fn, kwargs)`` tuple: the parameter follows its own profile function,
      evaluated as ``fn(t, **kwargs)``.

    Supported evaluation methods:

    - ``'stepped'``: a discrete cascaded approximation using ``options['N']``
      uniform sections (default ``N=50``). This is the default.
    - ``'riccati'``: a continuous ODE solver using the matrix Riccati
      differential equations. This method is not yet functional: calling
      :meth:`s` with it raises ``NotImplementedError``.

    Example
    --------
    .. code-block:: python

        import pmrf as prf
        from pmrf.core import PhysicalLine, ProfiledLine

        def linear_taper(t, start_val, end_val):
            return start_val + (end_val - start_val) * t

        tapered_line = ProfiledLine(
            PhysicalLine,
            linear_taper,
            length=0.1,
            zn={'start_val': 50.0, 'end_val': 100.0},
            epr=2.2,
            method='stepped',
            options={'N': 100},
        )

        freq = prf.Frequency(start=1, stop=10, npoints=101, unit='ghz')
        s_taper = tapered_line.s(freq)
    """

    # Config
    # Factory called with keyword arguments to build one uniform line section.
    line_fn: Callable[..., RLGCLine] = eqx.field(static=True)
    # Forwarded unchanged to every section built by `line_fn`.
    floating: bool = False
    # Profiled parameter name -> profile function evaluated at normalized t.
    profile_fns: Dict[str, Callable] = eqx.field(static=True)
    # Evaluation method: 'stepped' or 'riccati'.
    method: str = eqx.field(static=True)
    # Method options (a private copy; see __init__ for the defaults).
    options: dict = eqx.field(static=True)

    # Parameters and sub-models
    length: Parameter
    # Profiled parameter name -> keyword arguments of its profile function.
    profile_params: Dict[str, Dict[str, Parameter]]
    # Parameter name -> value for parameters that are constant along the line.
    uniform_params: Dict[str, Parameter]

    def __init__(
        self,
        line_fn: Callable,
        profile_fn: Callable | None = None,
        *,
        length: Any = 1,
        floating: bool = False,
        method: str = 'stepped',
        options: dict | None = None,
        name: str | None = None,
        z0: complex = 50.0,
        **line_params,
    ):
        """
        Build a profiled line.

        Parameters
        ----------
        line_fn : Callable
            Factory (e.g. an RLGC line class) called with keyword arguments to
            build one uniform line section.
        profile_fn : Callable, optional
            Default profile function for any line parameter given as a ``dict``.
            It is called as ``profile_fn(t, **kwargs)``.
        length : Any
            Total length of the line. Sections are ``length / N`` long in the
            stepped method.
        floating : bool
            Forwarded to every section built by ``line_fn``.
        method : str
            Evaluation method, ``'stepped'`` or ``'riccati'``.
        options : dict, optional
            Method options. The caller's dict is copied and never mutated.
            Missing keys are filled with defaults:

            - ``'stepped'``: ``N=50``.
            - ``'riccati'``: ``rtol=1e-5``, ``atol=1e-5``, ``max_steps=1000``,
              ``dz_eval=0.0``.
        name : str, optional
            Forwarded to the base ``Model`` constructor.
        z0 : complex
            Forwarded to the base ``Model`` constructor (presumably the port
            reference impedance).
        **line_params
            Line parameters. Each is either a plain value (uniform), a ``dict``
            (profiled with ``profile_fn``), or a ``(fn, kwargs)`` tuple
            (profiled with ``fn``).

        Raises
        ------
        ValueError
            If a parameter is given as a ``dict`` but ``profile_fn`` is None.
        """
        super().__init__(name=name, z0=z0)

        # Defaults. Copy first so that setdefault never mutates the caller's dict.
        options = dict(options) if options else {}
        if method == 'stepped':
            options.setdefault('N', 50)
        elif method == 'riccati':
            options.setdefault('rtol', 1e-5)
            options.setdefault('atol', 1e-5)
            options.setdefault('max_steps', 1000)
            # Section length at which dS/dz is finite-differenced in
            # `_s_riccati`. A small positive value avoids evaluating at exactly
            # 0.0 if the underlying line model has a singularity at length=0.
            options.setdefault('dz_eval', 0.0)

        # Members
        self.line_fn = line_fn
        self.floating = floating
        self.length = length
        self.method = method
        self.options = options

        # 1. Instantiate a dummy model from the uniform inputs only, to extract
        #    all default parameters. Profiled specs (dict/tuple) are skipped
        #    because line_fn cannot accept them directly.
        dummy_kwargs = {
            k: v for k, v in line_params.items()
            if not isinstance(v, (dict, tuple))
        }
        dummy_kwargs['length'] = self.length
        dummy_model = line_fn(**dummy_kwargs)

        # Keep every default parameter except 'length', which each section sets itself.
        merged_params = {
            k: v for k, v in dummy_model.named_params(include_fixed=True).items()
            if k != 'length'
        }

        # 2. Overlay the user's raw inputs.
        merged_params.update(line_params)

        # 3. Split the unified kwargs into Equinox-friendly PyTrees.
        profile_fns = {}
        profile_params = {}
        uniform_params = {}
        for k, v in merged_params.items():
            is_explicit_profile = (
                isinstance(v, tuple)
                and len(v) == 2
                and callable(v[0])
                and isinstance(v[1], dict)
            )
            if is_explicit_profile:
                profile_fns[k] = v[0]
                profile_params[k] = dict(v[1])
            elif isinstance(v, dict):
                if profile_fn is None:
                    raise ValueError(f"Parameter '{k}' was provided as a dict, but no `profile_fn` was specified.")
                profile_fns[k] = profile_fn
                profile_params[k] = dict(v)
            else:
                uniform_params[k] = v

        self.profile_fns = profile_fns
        self.profile_params = profile_params
        self.uniform_params = uniform_params

    @eqx.filter_jit
    def section(self, dz: float, **profiled_params) -> RLGCLine:
        """
        Build one uniform line section of length ``dz``.

        Parameters
        ----------
        dz : float
            Section length.
        **profiled_params
            Profiled parameter values for this section. They override uniform
            parameters of the same name.

        Returns
        -------
        RLGCLine
            The section returned by ``line_fn``.
        """
        # Build a fresh dict so that self.uniform_params is never mutated.
        # Later entries win: profiled values override uniform ones.
        current_params = {
            **self.uniform_params,
            'length': dz,
            'floating': self.floating,
            **profiled_params,
        }
        return self.line_fn(**current_params)

    def _profiled_params(self, t: float) -> dict:
        """Helper to evaluate all profile functions at normalized position t."""
        evaluated_params = {}
        for p_name, func in self.profile_fns.items():
            f_kwargs = self.profile_params.get(p_name, {})
            evaluated_params[p_name] = func(t, **f_kwargs)
        return evaluated_params

    @eqx.filter_jit
    def s(self, freq: Frequency) -> jnp.ndarray:
        """
        Return the S-parameters of the profiled line at ``freq``.

        Raises
        ------
        NotImplementedError
            If ``method == 'riccati'``, which is not yet functional.
        ValueError
            If ``method`` is not recognised.
        """
        if self.method == 'stepped':
            return self._s_stepped(freq)
        if self.method == 'riccati':
            # `_s_riccati` exists but is not yet functional. Route to it here once it is.
            raise NotImplementedError("Riccati method for ProfiledLine not yet functional")
        raise ValueError(f"Unknown evaluation method: '{self.method}'")

    @eqx.filter_jit
    def _s_stepped(self, freq: Frequency) -> jnp.ndarray:
        """
        Stepped approximation of the line's S-parameters.

        Cascades ``N = options['N']`` uniform sections of length ``length / N``.
        Each section uses the profiled parameters sampled at its normalized
        midpoint ``(i + 0.5) / N``.
        """
        N = self.options['N']
        dz = self.length / N
        t_centers = (jnp.arange(N) + 0.5) / N

        def s_section(t):
            sec = self.section(dz, **self._profiled_params(t))
            # Each section's z0 is returned alongside its S-matrix and passed to cascade_s.
            return sec.s(freq), jnp.asarray(sec.z0, dtype=complex)

        batch_s, batch_z0 = jax.vmap(s_section)(t_centers)
        S_cas, _ = cascade_s(batch_s, batch_z0)
        return S_cas

    @eqx.filter_jit
    def _s_riccati(self, freq: Frequency) -> jnp.ndarray:
        """
        Continuous evaluation by integrating a matrix Riccati ODE.

        The accumulated S-matrix is integrated over the normalized position
        ``t`` in ``[0, 1]``. Not yet functional (:meth:`s` refuses to call it);
        it is kept for development.
        """
        import diffrax

        # Base section length and step for the forward finite difference of dS/dz.
        dz_eval = self.options['dz_eval']
        dz_delta = 1e-8

        def s_derivative(t, S_state_tuple, args):
            # 1. Reconstruct the complex S-matrix from the real/imag tuple state.
            # Keeping the state real explicitly resolves the Diffrax complex dtype warning.
            S_real, S_imag = S_state_tuple
            S_state = S_real + 1j * S_imag

            # S_state shape: (..., 2n, 2n). Unpack the aggregated S-matrix blocks.
            n = S_state.shape[-1] // 2
            S11 = S_state[..., :n, :n]
            S12 = S_state[..., :n, n:]
            S21 = S_state[..., n:, :n]
            S22 = S_state[..., n:, n:]

            profiled_params = self._profiled_params(t)

            # Wrap the line construction to isolate `length`.
            def local_s_matrix(dz):
                return self.section(dz, **profiled_params).s(freq)

            # 2. Extract the generator matrix Q(z) = dS/dz by a forward finite
            # difference at dz_eval. This bypasses the jax.jacfwd crash caused by
            # complex matrix inversions/conjugations inside the renormalize_s
            # power-wave formulation.
            S_plus = local_s_matrix(dz_eval + dz_delta)
            S_base = local_s_matrix(dz_eval)
            Q = (S_plus - S_base) / dz_delta

            Q11 = Q[..., :n, :n]
            Q12 = Q[..., :n, n:]
            Q21 = Q[..., n:, :n]
            Q22 = Q[..., n:, n:]

            # Differential Redheffer star product
            dS11 = S12 @ Q11 @ S21
            dS12 = S12 @ (Q12 + Q11 @ S22)
            dS21 = (Q21 + S22 @ Q11) @ S21
            dS22 = Q22 + Q21 @ S22 + S22 @ Q12 + S22 @ Q11 @ S22

            # Reconstruct the block matrix.
            top = jnp.concatenate([dS11, dS12], axis=-1)
            bot = jnp.concatenate([dS21, dS22], axis=-1)
            dS_dt = jnp.concatenate([top, bot], axis=-2)

            # Chain rule from physical position z = t * length to normalized t.
            dS_dt_scaled = self.length * dS_dt

            # Return a real/imag tuple to maintain Diffrax stability.
            return jnp.real(dS_dt_scaled), jnp.imag(dS_dt_scaled)

        # Establish the initial condition S(0) = [0, I; I, 0]. A zero-length
        # section is built only to obtain the batch shape and port count.
        S_dummy = self.section(0.0, **self._profiled_params(0.0)).s(freq)
        batch_shape = S_dummy.shape[:-2]
        N_ports = S_dummy.shape[-1] // 2

        I_mat = jnp.broadcast_to(jnp.eye(N_ports, dtype=complex), batch_shape + (N_ports, N_ports))
        Z_mat = jnp.zeros_like(I_mat)
        top_init = jnp.concatenate([Z_mat, I_mat], axis=-1)
        bot_init = jnp.concatenate([I_mat, Z_mat], axis=-1)
        S_initial_complex = jnp.concatenate([top_init, bot_init], axis=-2)

        # Package the initial state as a real/imaginary tuple.
        y0 = (jnp.real(S_initial_complex), jnp.imag(S_initial_complex))

        # Set up and solve the ODE over normalized position t in [0, 1].
        term = diffrax.ODETerm(s_derivative)
        solver = diffrax.Dopri5()
        stepsize_controller = diffrax.PIDController(
            rtol=self.options['rtol'],
            atol=self.options['atol'],
        )
        solution = diffrax.diffeqsolve(
            term,
            solver,
            t0=0.0,
            t1=1.0,
            dt0=0.1,
            y0=y0,
            stepsize_controller=stepsize_controller,
            max_steps=self.options['max_steps'],
        )

        # Unpack the final tuple state back into a standard complex S-matrix array.
        S_final_real, S_final_imag = solution.ys[0][-1], solution.ys[1][-1]
        return S_final_real + 1j * S_final_imag