batch_axes (pmrf.batch_axes)
- pmrf.batch_axes(batched: Any, template: Any, *, is_leaf: Callable[[Any], bool] | None = None) → Any
Generates an in_axes PyTree by comparing a batched model to a template.
This function creates an axis specification PyTree compatible with vectorization tools like equinox.filter_vmap() or jax.vmap(). It structurally compares the batched tree against an unbatched template, returning 0 for arrays that contain a leading batch dimension and None for static or unbatched leaves.
This is a lower-level routine. For most standard parameter evaluation workflows, users should prefer the high-level pmrf.sweep() function, which completely automates axis extraction and vectorization.
- Parameters:
batched (Any) – The PyTree containing the batched (sampled or swept) parameters.
template (Any) – A single-sample PyTree used to define the structural layout and base shapes.
is_leaf (Callable[[Any], bool], optional) – An optional callable defining whether a node in the PyTree should be treated as a leaf.
- Returns:
A PyTree of axis specifications (0 or None) matching the structure of the input trees.
- Return type:
Any
Examples
Extract axes for custom, lower-level vectorization:
>>> import equinox as eqx
>>> import pmrf
>>> axes = pmrf.batch_axes(sampled_model, template=best_model)
>>> batched_s = eqx.filter_vmap(lambda m: m.s(freq), in_axes=(axes,))(sampled_model)
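The batched-versus-template comparison described above can be illustrated with plain JAX. This is a minimal sketch, not pmrf's implementation: sketch_batch_axes, the dictionary keys, and evaluate are hypothetical names, and the rule used here (a leaf counts as batched when it has exactly one more dimension than its template counterpart) is an assumption about how such a comparison could work.

```python
import jax
import jax.numpy as jnp


def sketch_batch_axes(batched, template, is_leaf=None):
    """Hypothetical stand-in for pmrf.batch_axes (assumed rule, see lead-in)."""

    def axis_for(b_leaf, t_leaf):
        # Leaves without an `ndim` attribute (non-arrays) are treated as static.
        if not hasattr(b_leaf, "ndim") or not hasattr(t_leaf, "ndim"):
            return None
        # Assumption: one extra leading dimension marks a batched leaf.
        return 0 if b_leaf.ndim == t_leaf.ndim + 1 else None

    # Walk both trees in lockstep; they must share the same structure.
    return jax.tree_util.tree_map(axis_for, batched, template, is_leaf=is_leaf)


# Single-sample template and a version where only "gain" was swept.
template = {"gain": jnp.array(2.0), "offset": jnp.array(1.0)}
batched = {"gain": jnp.array([1.0, 2.0, 3.0]), "offset": jnp.array(1.0)}

axes = sketch_batch_axes(batched, template)  # {'gain': 0, 'offset': None}


def evaluate(params):
    # Toy single-sample evaluation standing in for a model method.
    return params["gain"] * 10.0 + params["offset"]


# Map over the leaves marked 0 and broadcast the leaves marked None.
outputs = jax.vmap(evaluate, in_axes=(axes,))(batched)
```

Because the unbatched leaf is marked None, jax.vmap passes it to every call unchanged instead of trying to slice it, which is why a mixed tree of swept and fixed parameters can be vectorized in a single call.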