batch_axes (pmrf.batch_axes)

pmrf.batch_axes(batched: Any, template: Any, *, is_leaf: Callable[[Any], bool] | None = None) → Any

Generates an in_axes PyTree by comparing a batched model to a template.

This function creates an axis-specification PyTree compatible with vectorization tools such as equinox.filter_vmap() or jax.vmap(). It walks the batched tree and the unbatched template in parallel, returning 0 for array leaves that carry a leading batch dimension relative to the template and None for static or unbatched leaves.
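The comparison described above can be sketched in plain JAX. The function below, batch_axes_sketch, is a hypothetical illustration and not pmrf's implementation: it assumes that a leaf counts as batched when it is an array whose shape equals the template leaf's shape with exactly one extra leading dimension, and that every other leaf (including non-array values) maps to None.

```python
import jax
import jax.numpy as jnp
import numpy as np


def batch_axes_sketch(batched, template, *, is_leaf=None):
    """Hypothetical stand-in for pmrf.batch_axes (not the library's code).

    Returns 0 for array leaves whose shape is the template leaf's shape with
    one extra leading dimension, and None for every other leaf.
    """
    array_types = (jax.Array, np.ndarray)

    def leaf_axis(b, t):
        # Non-array leaves (Python scalars, strings, ...) are treated as static.
        if not (isinstance(b, array_types) and isinstance(t, array_types)):
            return None
        # Assumed rule: a batched leaf has exactly one extra leading dimension.
        if b.ndim == t.ndim + 1 and b.shape[1:] == t.shape:
            return 0
        return None

    # tree_map walks both trees in lockstep; they must share the same structure.
    return jax.tree_util.tree_map(leaf_axis, batched, template, is_leaf=is_leaf)


# Usage: "gain" is swept over 5 values, "offset" is shared by every sample.
template = {"gain": jnp.array(2.0), "offset": jnp.array([1.0, 2.0, 3.0])}
batched = {"gain": jnp.linspace(1.0, 3.0, 5), "offset": jnp.array([1.0, 2.0, 3.0])}

axes = batch_axes_sketch(batched, template)
print(axes)  # {'gain': 0, 'offset': None}


def evaluate(params):
    return params["gain"] * params["offset"]


# in_axes is a one-element tuple because evaluate takes a single argument.
out = jax.vmap(evaluate, in_axes=(axes,))(batched)
print(out.shape)  # (5, 3)
```

Because the resulting tree contains only 0 and None, it can be passed directly as in_axes: leaves marked 0 are split along their first dimension, while leaves marked None are broadcast unchanged to every batch element.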

This is a lower-level routine. For most parameter-evaluation workflows, prefer the high-level pmrf.sweep() function, which automates both axis extraction and vectorization.

Parameters:
  • batched (Any) – The PyTree containing the batched (sampled or swept) parameters.

  • template (Any) – A single-sample PyTree used to define the structural layout and base shapes.

  • is_leaf (Callable[[Any], bool], optional) – A callable that returns True for nodes that should be treated as leaves rather than traversed further. Defaults to None.

Returns:

A PyTree of axis specifications (0 or None) matching the structure of the input trees.

Return type:

Any

Examples

Extract axes for custom, lower-level vectorization:

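>>> import equinox as eqx
>>> import pmrf
>>> # Assumes sampled_model (batched), best_model (template) and freq already exist.
>>> axes = pmrf.batch_axes(sampled_model, template=best_model)
>>> batched_s = eqx.filter_vmap(lambda m: m.s(freq), in_axes=(axes,))(sampled_model)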