TargetLoss

class pmrf.evaluators.TargetLoss(predictor: Callable[[Model, Frequency], jnp.ndarray], target: Any, loss: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray])

Bases: AbstractEvaluator

Computes a loss between a model prediction and some target.

Parameters:
  • predictor (Callable[[pmrf.models.base.Model, pmrf.frequency.Frequency], jax.Array]) – The predictor (e.g. another Evaluator) that extracts features from the model. It is called with a model and a frequency object and returns an array. Can be a function or a PyTree with optional parameters.

  • target (jax.Array) – The fixed or ‘true’ target that the loss function compares the prediction against.

  • loss (Callable[[jax.Array, jax.Array], jax.Array]) – The loss function, which takes (y_true, y_pred) and returns a loss metric. Can be a function or a PyTree with optional parameters. See pmrf.losses for common losses.
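The parameter descriptions say that both predictor and loss may be "a PyTree with optional parameters". The minimal sketch below illustrates what that can mean, using only core JAX: a loss with a weight, registered as a PyTree so that JAX sees the weight as a leaf. WeightedMSE, the toy predictor, the dict-based "model" and the plain frequency array are all hypothetical stand-ins, not pmrf objects; for the actual built-in losses see pmrf.losses.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class WeightedMSE:
    """Hypothetical loss with one parameter, written as a callable PyTree."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, y_true, y_pred):
        # Same (y_true, y_pred) argument order as described for `loss`.
        return self.weight * jnp.mean((y_true - y_pred) ** 2)

    def tree_flatten(self):
        # The weight is a child leaf, so JAX transforms can see it.
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def predictor(model, frequency):
    # Hypothetical predictor: a flat response equal to the model's "R" value.
    return jnp.full(frequency.shape, model["R"])


freqs = jnp.linspace(1e9, 2e9, 5)   # stand-in for a pmrf Frequency object
target = jnp.full(freqs.shape, 52.0)
loss = WeightedMSE(0.5)

value = loss(target, predictor({"R": 50.0}, freqs))   # 0.5 * (52 - 50)**2
leaves = jax.tree_util.tree_leaves(loss)              # the loss's parameters
print(float(value), leaves)  # 2.0 [0.5]
```

Registering the loss as a PyTree is one way to make its weight visible to JAX transforms; a plain function works just as well when the loss has no parameters.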

__call__(model: Model, frequency: Frequency, **kwargs) → Array

Evaluate the loss between the predictor's output for the model at the given frequency points and the fixed target.

Parameters:
  • model (Model) – The model instance to evaluate.

  • frequency (Frequency) – The frequency object defining the evaluation points.

  • **kwargs (dict) – Additional keyword arguments for the evaluation process.

Returns:

The loss between the prediction and the target, as returned by the loss function.

Return type:

jax.Array
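As a rough picture of what calling the evaluator involves, the sketch below mirrors the documented attributes (predictor, target, loss) in a small stand-in class. It assumes that __call__ computes loss(target, predictor(model, frequency)), which follows from the (y_true, y_pred) convention above but is not confirmed source code. TargetLossSketch, the dict-based "model" and the plain frequency array are hypothetical; the sketch also ignores **kwargs, which pmrf may use differently. Because the result is a scalar JAX array, it can be differentiated with respect to the model's parameters.

```python
from dataclasses import dataclass
from typing import Any, Callable

import jax
import jax.numpy as jnp


@dataclass
class TargetLossSketch:
    # Hypothetical stand-in mirroring the documented attributes; not pmrf code.
    predictor: Callable[[Any, jax.Array], jax.Array]
    target: jax.Array
    loss: Callable[[jax.Array, jax.Array], jax.Array]

    def __call__(self, model, frequency, **kwargs):
        # Assumption: extra keyword arguments are simply ignored here.
        y_pred = self.predictor(model, frequency)
        return self.loss(self.target, y_pred)


def flat_response(model, frequency):
    # Hypothetical predictor: the model's "R" value at every frequency point.
    return jnp.full(frequency.shape, model["R"])


def mse(y_true, y_pred):
    return jnp.mean((y_true - y_pred) ** 2)


freqs = jnp.linspace(1e9, 2e9, 5)   # stand-in for a pmrf Frequency object
evaluator = TargetLossSketch(flat_response, jnp.full(freqs.shape, 52.0), mse)

model = {"R": 50.0}
value = evaluator(model, freqs)     # (52 - 50)**2 = 4.0
# Gradient of the loss with respect to the model's leaves: approximately -4.0.
grads = jax.grad(lambda m: evaluator(m, freqs))(model)
```

Writing the evaluator as "predictor, then loss against a fixed target" keeps feature extraction and the error metric independent, so either can be swapped without touching the other.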

loss: Callable[[Array, Array], Array]

The active loss function.

predictor: Callable[[Model, Frequency], Array]

The active predictor instance.

target: Array

The fixed target data.