partition (pmrf.partition)
- pmrf.partition(pytree: Any, filter_spec: Any) → Tuple[Any, Any]
Splits a PyTree into two halves based on a filtering rule.
This function separates a single PyTree into a “dynamic” tree (containing all leaves that match the
filter_spec) and a “static” tree (containing everything else). The missing leaves in each half are replaced with None stubs to preserve the overall structural layout.
This is a lower-level, structural routine that bridges physical RF models and plain numerical arrays. It is useful when interfacing with external machine learning libraries or strict JAX transformations that fail on non-array physics context (such as component names or boolean flags). Under the hood, it is a thin wrapper around Equinox’s partitioning utility.
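The None-stub behaviour described above can be sketched with plain jax.tree_util calls. This is an illustrative stand-in under stated assumptions, not pmrf's or Equinox's actual implementation: partition_sketch and the dictionary "model" are hypothetical names invented for this example.

```python
import jax
import jax.numpy as jnp


def partition_sketch(pytree, predicate):
    """Illustrative stand-in for the split described above (not pmrf's code)."""
    # Leaves that satisfy the predicate stay in the dynamic tree; the rest become None.
    dynamic = jax.tree_util.tree_map(lambda x: x if predicate(x) else None, pytree)
    # The static tree is the complement: it keeps exactly the leaves dropped above.
    static = jax.tree_util.tree_map(lambda x: None if predicate(x) else x, pytree)
    return dynamic, static


# Hypothetical stand-in for an RF model: one array parameter plus non-array context.
model = {"name": "R1", "resistance": jnp.array(50.0), "enabled": True}

dynamic, static = partition_sketch(model, lambda x: isinstance(x, jax.Array))
# dynamic holds only the array leaf; static holds the name and the boolean flag.
# Both halves keep the same three keys, with None where a leaf was moved out.
```

Because both halves keep the original keys, either one can be passed around on its own without losing track of where each leaf belongs.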
- Parameters:
pytree (Any) – The original PyTree (e.g., an RF model) to be split.
filter_spec (Any) – A callable (e.g., lambda x: isinstance(x, jax.Array)) or a boolean PyTree mask (like the one generated by pmrf.batch_mask()) dictating which leaves belong in the dynamic half.
- Returns:
A tuple of (dynamic_tree, static_tree).
- Return type:
Tuple[Any, Any]
Examples
Split a model into differentiable arrays and static metadata:
>>> import jax
>>> import pmrf
>>> arrays, metadata = pmrf.partition(model, filter_spec=lambda x: isinstance(x, jax.Array))
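When filter_spec is a boolean PyTree mask rather than a callable, the same idea applies leaf by leaf. The sketch below is a hedged illustration built only on jax.tree_util: partition_by_mask and combine_sketch are hypothetical helpers, the mask is written by hand rather than produced by pmrf.batch_mask(), and it assumes the mask has exactly the same structure as the tree. It also shows why the None stubs matter: the two halves can be merged back into the original layout.

```python
import jax
import jax.numpy as jnp


def partition_by_mask(pytree, mask):
    """Split by a boolean mask with the same structure as pytree (illustrative only)."""
    # True in the mask sends the leaf to the dynamic half, False to the static half.
    dynamic = jax.tree_util.tree_map(lambda x, keep: x if keep else None, pytree, mask)
    static = jax.tree_util.tree_map(lambda x, keep: None if keep else x, pytree, mask)
    return dynamic, static


def combine_sketch(dynamic, static):
    """Merge two halves by taking whichever side is not a None stub."""
    # Treat None as a leaf so both halves flatten to the same structure.
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        is_leaf=lambda x: x is None,
    )


# Hypothetical model and hand-written mask: only "gain" is marked as dynamic.
model = {"gain": jnp.array(2.0), "offset": jnp.array(0.5), "label": "amp"}
mask = {"gain": True, "offset": False, "label": False}

dynamic, static = partition_by_mask(model, mask)
restored = combine_sketch(dynamic, static)
```

A mask allows finer control than an array-type check: here "offset" is an array but is still routed to the static half.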