partition (pmrf.partition)

pmrf.partition(pytree: Any, filter_spec: Any) → Tuple[Any, Any]

Splits a PyTree into two halves based on a filtering rule.

This function separates a single PyTree into a “dynamic” tree, which contains every leaf that matches filter_spec, and a “static” tree, which contains everything else. In each half, the leaves that belong to the other half are replaced with None stubs, so both trees keep the original structural layout.
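The splitting behaviour can be illustrated with a minimal sketch built only on jax.tree_util. The helpers partition_sketch and combine_sketch are hypothetical stand-ins written for this illustration and are not part of pmrf. The sketch handles callable filters only, and a plain dict with made-up field names stands in for an RF model.

```python
import jax
import jax.numpy as jnp


def partition_sketch(pytree, filter_fn):
    """Hypothetical stand-in for pmrf.partition (callable filters only)."""
    # Leaves that pass the filter stay in the dynamic tree; the rest become None.
    dynamic = jax.tree_util.tree_map(lambda x: x if filter_fn(x) else None, pytree)
    # The static tree is the complement: filtered-out leaves are kept, the rest become None.
    static = jax.tree_util.tree_map(lambda x: None if filter_fn(x) else x, pytree)
    return dynamic, static


def combine_sketch(dynamic, static):
    """Hypothetical inverse: fill each None stub in `dynamic` from `static`."""
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d,
        dynamic,
        static,
        is_leaf=lambda x: x is None,  # treat None stubs as leaves, not empty subtrees
    )


# A plain dict standing in for an RF model (hypothetical field names).
model = {"name": "R1", "resistance": jnp.array(75.0), "enabled": True}

dynamic, static = partition_sketch(model, lambda x: isinstance(x, jax.Array))
print(dynamic)  # name and enabled are None stubs; resistance keeps its array
print(static)   # resistance is a None stub; name and enabled keep their values

# Because both halves share the original layout, they can be merged back together.
restored = combine_sketch(dynamic, static)
```

The None stubs are what make the round trip possible: every position in the original tree exists in both halves, so recombining is a leaf-by-leaf choice between the two.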

This is a low-level structural utility for separating the numerical array data in an RF model from its non-array context. It is useful when passing a model to external machine-learning libraries or to JAX transformations that reject non-array leaves such as component names or boolean flags (for example, jax.grad cannot differentiate with respect to a string or a bool). Internally, it is a thin wrapper around Equinox’s partitioning utility (equinox.partition).
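To show why the split matters for JAX transformations, the sketch below differentiates a toy loss with respect to the array leaves of a model only. It redefines the same hypothetical helpers (partition_sketch, combine_sketch) so that it runs on its own; in practice the split would come from pmrf.partition, and the recombination helper is an assumption of this sketch rather than a documented pmrf function. The loss, the squared reflection coefficient of a resistive load against 50 Ω, is chosen purely for illustration.

```python
import jax
import jax.numpy as jnp


def partition_sketch(pytree, filter_fn):
    # Hypothetical stand-in for pmrf.partition: non-matching leaves become None.
    dynamic = jax.tree_util.tree_map(lambda x: x if filter_fn(x) else None, pytree)
    static = jax.tree_util.tree_map(lambda x: None if filter_fn(x) else x, pytree)
    return dynamic, static


def combine_sketch(dynamic, static):
    # Hypothetical inverse: fill each None stub in `dynamic` from `static`.
    return jax.tree_util.tree_map(
        lambda d, s: s if d is None else d, dynamic, static,
        is_leaf=lambda x: x is None,
    )


# Toy "model": a string name and a Python bool alongside one array parameter.
# Calling jax.grad on this dict directly would fail on the str and bool leaves.
model = {"name": "R1", "resistance": jnp.array(75.0), "enabled": True}


def loss(dynamic, static):
    m = combine_sketch(dynamic, static)
    if not m["enabled"]:  # static flag: an ordinary Python bool, never traced
        return jnp.array(0.0)
    r = m["resistance"]
    gamma = (r - 50.0) / (r + 50.0)  # reflection coefficient against 50 ohms
    return gamma ** 2


dynamic, static = partition_sketch(model, lambda x: isinstance(x, jax.Array))
# Differentiate with respect to the dynamic half only; `static` is closed over.
grads = jax.grad(lambda d: loss(d, static))(dynamic)
print(grads["resistance"])  # d(gamma**2)/dR at R = 75 ohms
```

Because the static half is closed over rather than traced, the boolean flag remains a concrete Python value and can drive ordinary control flow, while the gradient tree mirrors the dynamic half (None wherever a static leaf was removed).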

Parameters:
  • pytree (Any) – The original PyTree (e.g., an RF model) to be split.

  • filter_spec (Any) – A callable (e.g., lambda x: isinstance(x, jax.Array)) or a boolean PyTree mask (like the one generated by pmrf.batch_mask()) dictating which leaves belong in the dynamic half.
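The two forms of filter_spec relate in a simple way: a callable can be evaluated on every leaf to produce a boolean mask of the same shape. The sketch below is hypothetical. mask_partition_sketch assumes the mask has exactly the same structure as the tree, does not reproduce any additional handling that pmrf or Equinox may apply, and builds its mask by hand rather than with pmrf.batch_mask().

```python
import jax
import jax.numpy as jnp


def mask_partition_sketch(pytree, mask):
    # Hypothetical: `mask` mirrors `pytree`, with True marking dynamic leaves.
    dynamic = jax.tree_util.tree_map(lambda keep, x: x if keep else None, mask, pytree)
    static = jax.tree_util.tree_map(lambda keep, x: None if keep else x, mask, pytree)
    return dynamic, static


# A plain dict standing in for an RF model (hypothetical field names).
model = {"name": "R1", "resistance": jnp.array(75.0), "enabled": True}

# Boolean mask written by hand: only the resistance belongs in the dynamic half.
mask = {"name": False, "resistance": True, "enabled": False}
dynamic, static = mask_partition_sketch(model, mask)

# Evaluating a callable filter on every leaf yields an equivalent mask.
derived = jax.tree_util.tree_map(lambda x: isinstance(x, jax.Array), model)
```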

Returns:

A tuple of (dynamic_tree, static_tree).

Return type:

Tuple[Any, Any]

Examples

Split a model into differentiable arrays and static metadata:

>>> import jax
>>> import pmrf
>>> # `model` is an existing pmrf model instance (not constructed here)
>>> arrays, metadata = pmrf.partition(model, filter_spec=lambda x: isinstance(x, jax.Array))