batch_mask (pmrf.batch_mask)
- pmrf.batch_mask(batched: Any, template: Any, *, is_leaf: Callable[[Any], bool] | None = None) → Any
Generates a boolean PyTree mask identifying which leaves have a batch dimension.
This function compares a batched PyTree against an unbatched template structure to determine which specific arrays contain a leading batch axis. It returns True for batched arrays and False for unbatched arrays, static variables, strings, or constants.
This is a lower-level routine. Users typically rely on high-level wrappers such as pmrf.sweep() to handle batched evaluation automatically, but batch_mask is useful in advanced workflows (such as machine learning or custom JAX vectorization) that need strict, explicit control over separating batched training data from fixed physics parameters.
- Parameters:
batched (Any) – The PyTree containing the batched (sampled or swept) parameters.
template (Any) – An unbatched PyTree used as a structural reference. This determines the base dimensionality of the physical parameters.
is_leaf (Callable[[Any], bool], optional) – An optional callable defining whether a node in the PyTree should be treated as a leaf.
- Returns:
A PyTree of booleans matching the structure of the input trees.
- Return type:
Any
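To make the batched-versus-template comparison concrete, the following is a minimal sketch of how such a mask could be computed with jax.tree_util.tree_map. It is not pmrf's implementation: batch_mask_sketch is a hypothetical name, and it assumes a leaf counts as batched exactly when its rank is one higher than the matching template leaf. Objects without an ndim attribute (strings, Python scalars) are treated as rank 0.

```python
import jax
import jax.numpy as jnp


def batch_mask_sketch(batched, template, *, is_leaf=None):
    """Hypothetical stand-in for pmrf.batch_mask (illustration only).

    Assumption: a leaf is batched when its rank is exactly one higher than
    the rank of the corresponding template leaf (one extra leading axis).
    """

    def leaf_is_batched(b_leaf, t_leaf):
        # Objects without an `ndim` attribute (strings, Python scalars,
        # static metadata) are treated as rank 0.
        b_rank = getattr(b_leaf, "ndim", 0)
        t_rank = getattr(t_leaf, "ndim", 0)
        return b_rank == t_rank + 1

    # Walk both trees in lockstep; they must share the same structure.
    return jax.tree_util.tree_map(
        leaf_is_batched, batched, template, is_leaf=is_leaf
    )


# Template: one unbatched set of (hypothetical) parameters.
template = {"R": jnp.array(50.0), "C": jnp.ones(3), "name": "line"}
# Batched: 8 samples of R and C; the name string stays static.
batched = {"R": jnp.ones(8), "C": jnp.ones((8, 3)), "name": "line"}

mask = batch_mask_sketch(batched, template)
# mask == {"C": True, "R": True, "name": False}
```

A pure rank-difference rule is only one possible heuristic; a stricter variant might also check that the trailing dimensions match the template leaf's shape.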
Examples
Isolate batched data to prepare for normalizing flow training:
>>> mask = pmrf.batch_mask(sampled_model, template=best_model)
>>> dynamic_samples, static_physics = pmrf.partition(sampled_model, mask)
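The example above splits a sampled model into batched and static halves with pmrf.partition. The snippet below sketches that idea with plain JAX tree utilities, in the style of equinox.partition and equinox.combine. split_by_mask and merge are hypothetical helpers, the dictionary stands in for a real model, and nothing here assumes pmrf.partition's exact return format.

```python
import jax
import jax.numpy as jnp


def split_by_mask(tree, mask):
    """Hypothetical helper: split `tree` into (dynamic, static) halves.

    Leaves where `mask` is True go to `dynamic`; all others go to `static`.
    Each half holds None in the slots that belong to the other half.
    """
    dynamic = jax.tree_util.tree_map(lambda m, x: x if m else None, mask, tree)
    static = jax.tree_util.tree_map(lambda m, x: None if m else x, mask, tree)
    return dynamic, static


def merge(dynamic, static):
    """Hypothetical inverse of split_by_mask: keep whichever half is not None."""
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        # Treat None placeholders as leaves so both trees line up.
        is_leaf=lambda x: x is None,
    )


sampled = {"R": jnp.ones(8), "C": jnp.ones((8, 3)), "name": "line"}
mask = {"R": True, "C": True, "name": False}

dynamic_samples, static_physics = split_by_mask(sampled, mask)
# dynamic_samples holds only the batched arrays (e.g. for a flow model);
# static_physics keeps the non-batched metadata.
rebuilt = merge(dynamic_samples, static_physics)
```

Because JAX treats None as an empty subtree, jax.tree_util.tree_leaves(dynamic_samples) returns only the batched arrays, which is convenient when handing them to a learning model; merge then restores the full structure.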