TargetLoss
- class pmrf.evaluators.TargetLoss(predictor: Callable[[Model, Frequency], jnp.ndarray], target: Any, loss: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray])
Bases: AbstractEvaluator
Computes a loss between a model prediction and some target.
- Parameters:
predictor (Callable[[pmrf.models.base.Model, pmrf.frequency.Frequency], jax.jaxlib._jax.Array]) – The predictor (e.g. another Evaluator) that extracts model features. Can be a function or a PyTree with optional parameters.
target (jax.jaxlib._jax.Array) – The fixed or ‘true’ target that the loss function should compare the prediction to.
loss (Callable[[jax.jaxlib._jax.Array, jax.jaxlib._jax.Array], jax.jaxlib._jax.Array]) – The loss function that takes (y_true, y_pred) and returns a loss metric. Can be a function or a PyTree with optional parameters. See pmrf.losses for common losses.
- __call__(model: Model, frequency: Frequency, **kwargs) Array
Evaluate the model response over the specified frequency range.
- loss: Callable[[Array, Array], Array]
The active loss function.
- target: Array
The fixed target data.
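The composition described above can be illustrated with a minimal, self-contained sketch. It does not import pmrf: `TargetLossSketch`, `predict_gain`, `mse`, and the dict-based model are hypothetical stand-ins, and a plain array takes the place of `Frequency`. The only behavior taken from this page is that the predictor is called with `(model, frequency)` and the loss is called with `(y_true, y_pred)`; the actual internals of `TargetLoss` may differ.

```python
import jax
import jax.numpy as jnp


def predict_gain(model, frequency):
    """Hypothetical predictor: magnitude of a first-order low-pass response."""
    return model["gain"] / jnp.sqrt(1.0 + (frequency / model["cutoff"]) ** 2)


def mse(y_true, y_pred):
    """Hypothetical loss with the documented (y_true, y_pred) argument order."""
    return jnp.mean((y_true - y_pred) ** 2)


class TargetLossSketch:
    """Illustrative stand-in for the documented pattern, not the pmrf class."""

    def __init__(self, predictor, target, loss):
        self.predictor = predictor
        self.target = jnp.asarray(target)
        self.loss = loss

    def __call__(self, model, frequency, **kwargs):
        # Assumption: extra keyword arguments are accepted but ignored here.
        y_pred = self.predictor(model, frequency)
        return self.loss(self.target, y_pred)


# A plain array stands in for a Frequency object in this sketch.
frequency = jnp.linspace(1.0, 10.0, 5)

# Build the fixed target from a known "true" model.
true_model = {"gain": 2.0, "cutoff": 5.0}
target = predict_gain(true_model, frequency)

evaluator = TargetLossSketch(predict_gain, target, mse)

# The loss vanishes at the true parameters and is positive elsewhere.
loss_at_truth = evaluator(true_model, frequency)
guess = {"gain": 1.0, "cutoff": 5.0}
loss_at_guess = evaluator(guess, frequency)

# The evaluator is an ordinary JAX-traceable function of the model,
# so gradients with respect to the model parameters are available.
grads = jax.grad(lambda m: evaluator(m, frequency))(guess)
```

Because the sketch returns a scalar and uses only JAX operations, it can be passed to `jax.grad` or an optimizer loop directly. Whether the real class is used this way depends on how pmrf's fitting routines consume evaluators, which this page does not state.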