# Source code for pmrf.noise_models

"""
Models that represent the noise of measurements, generally for use in a likelihood.
"""
import jax.numpy as jnp
from typing import Callable

import parax as prx
from pmrf.core import NoiseModel


class ReflectionTransmissionNoise(NoiseModel):
    """
    Reflection and transmission coefficient noise model.

    Maps underlying `gamma` and `tau` noises to a full matrix based on the
    specified port axes: `gamma` is placed where the two port indices are
    equal (the diagonal) and `tau` everywhere else (the off-diagonal).

    Supports both circularly-symmetric underlying noise (returns a single
    array) and general complex noise (returns a tuple of (hermitian, pseudo)).

    Attributes:
        gamma: Diagonal noise. Either a fixed value, or a callable of `y_pred`
            returning an array or a `(hermitian, pseudo)` tuple.
        tau: Off-diagonal noise, with the same conventions as `gamma`.
        port_axes: The two distinct axes of `y_pred` that index the ports.
            Defaults to the last two axes.
    """

    gamma: prx.Parameter | Callable[[jnp.ndarray], jnp.ndarray | tuple[jnp.ndarray, jnp.ndarray]]
    tau: prx.Parameter | Callable[[jnp.ndarray], jnp.ndarray | tuple[jnp.ndarray, jnp.ndarray]]
    port_axes: tuple[int, int] = prx.field(static=True, default=(-2, -1))

    def _normalize_port_axes(self, ndim: int) -> tuple[int, int]:
        """Validate `port_axes` for an `ndim`-dimensional input and return them non-negative.

        Raises:
            ValueError: If `ndim < 2`, an axis is out of range, or both axes
                refer to the same dimension.
        """
        if ndim < 2:
            raise ValueError(
                f"port_axes {self.port_axes} require an input with at least 2 dimensions, "
                f"got {ndim}."
            )
        normalized = []
        for axis in self.port_axes:
            # Reject out-of-range axes explicitly; a bare modulo would silently
            # wrap them onto an unrelated dimension.
            if not -ndim <= axis < ndim:
                raise ValueError(
                    f"port_axes {self.port_axes} are out of range for an input with {ndim} dimensions."
                )
            normalized.append(axis % ndim)
        ax1, ax2 = normalized
        if ax1 == ax2:
            raise ValueError(f"port_axes {self.port_axes} must refer to two distinct dimensions.")
        return ax1, ax2

    def _build_matrix(self, val_gamma: jnp.ndarray, val_tau: jnp.ndarray, target_shape: tuple) -> jnp.ndarray:
        """Place `val_gamma` on the port diagonal and `val_tau` off it.

        Args:
            val_gamma: Diagonal values; must broadcast against `target_shape`.
            val_tau: Off-diagonal values; must broadcast against `target_shape`.
            target_shape: Shape of the prediction the noise is built for.

        Returns:
            The `jnp.where` selection between the two inputs using a port-identity
            mask. Its shape is the broadcast of the mask and the inputs, so
            scalar inputs give size-1 entries on the non-port axes.

        Raises:
            ValueError: If `port_axes` is invalid for `target_shape`, or the two
                port dimensions differ in size.
        """
        ax1, ax2 = self._normalize_port_axes(len(target_shape))
        nports = target_shape[ax1]
        if nports != target_shape[ax2]:
            raise ValueError(
                f"Dimensions at port_axes {self.port_axes} must be equal for a square matrix. "
                f"Got {target_shape[ax1]} and {target_shape[ax2]}."
            )

        # Boolean identity with size-1 entries on every non-port axis so it
        # broadcasts against target_shape. Row-major reshape maps the identity's
        # rows/cols onto the two port axes; the identity is symmetric, so the
        # ordering of ax1 and ax2 does not matter.
        mask_shape = [1] * len(target_shape)
        mask_shape[ax1] = nports
        mask_shape[ax2] = nports
        diagonal_mask = jnp.eye(nports, dtype=bool).reshape(mask_shape)

        return jnp.where(diagonal_mask, val_gamma, val_tau)

    @staticmethod
    def _split_hermitian_pseudo(value):
        """Split a noise value into (hermitian, pseudo); a non-tuple has zero pseudo part."""
        if isinstance(value, tuple):
            return value[0], value[1]
        return value, jnp.zeros_like(value)

    def __call__(self, y_pred: jnp.ndarray):
        """Build the noise matrix for `y_pred`.

        Callable `gamma`/`tau` are evaluated on `y_pred`; other values are
        used as-is.

        Returns:
            A single matrix if neither `gamma` nor `tau` resolves to a tuple.
            Otherwise, a `(hermitian, pseudo)` tuple of matrices, where a side
            given without a pseudo part contributes zeros to it.
        """
        val_gamma = self.gamma(y_pred) if callable(self.gamma) else self.gamma
        val_tau = self.tau(y_pred) if callable(self.tau) else self.tau

        if not isinstance(val_gamma, tuple) and not isinstance(val_tau, tuple):
            # Standard real or circularly-symmetric complex case.
            return self._build_matrix(val_gamma, val_tau, y_pred.shape)

        # General complex case: route both Hermitian and pseudo variances.
        gamma_h, gamma_p = self._split_hermitian_pseudo(val_gamma)
        tau_h, tau_p = self._split_hermitian_pseudo(val_tau)
        matrix_h = self._build_matrix(gamma_h, tau_h, y_pred.shape)
        matrix_p = self._build_matrix(gamma_p, tau_p, y_pred.shape)
        return matrix_h, matrix_p
class RadialTangentialNoise(NoiseModel):
    """
    Radial/tangential complex-valued heteroscedastic variance noise model.

    Models noise as relative radial and tangential noise that scales with the
    magnitude of the signal:

        var_rad = (|y_pred| * magnitude) ** 2
        var_tan = (|y_pred| * phase) ** 2

    Each of `magnitude` and `phase` may be a fixed value or a callable of
    `y_pred`, which is evaluated on the prediction.

    NOTE(review): because they are multiplied by |y_pred| and squared,
    `magnitude` and `phase` behave as relative standard deviations rather
    than variances. Confirm this matches the intended parameterization.

    Returns the hermitian and pseudo variance as a tuple.
    """

    magnitude: prx.Parameter | jnp.ndarray | Callable
    phase: prx.Parameter | jnp.ndarray | Callable

    def __call__(self, y_pred: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Return `(variance, pseudo_variance)` for `y_pred`.

        `variance = var_rad + var_tan` is the real-valued Hermitian variance.
        `pseudo_variance = (var_rad - var_tan) * exp(2j * arg(y_pred))` captures
        the impropriety. Together these are the second moments of radial and
        tangential components that are independent of each other.
        """
        # The annotations admit callables, so resolve them against the prediction
        # instead of multiplying a function object.
        rel_magnitude = self.magnitude(y_pred) if callable(self.magnitude) else self.magnitude
        rel_phase = self.phase(y_pred) if callable(self.phase) else self.phase

        # 1. Base variances, scaled by the signal magnitude.
        mag = jnp.abs(y_pred)
        var_rad = (mag * rel_magnitude) ** 2
        var_tan = (mag * rel_phase) ** 2

        # 2. Total Hermitian variance (real-valued).
        variance = var_rad + var_tan

        # 3. Pseudo-variance. exp(2j * arg(y_pred)) == (y_pred / |y_pred|)**2.
        # The epsilon avoids 0/0 at y_pred == 0, where both variances are zero anyway.
        unit_phase_sq = (y_pred / (mag + 1e-12)) ** 2
        pseudo_variance = (var_rad - var_tan) * unit_phase_sq
        return variance, pseudo_variance