Skip to content

Filters

parax.remove(pytree, condition, *, stop_if=None)

Removes nodes from a PyTree that match a given condition.

Replaces matching nodes with None. Halts traversal at matching nodes, as well as any nodes matching stop_if.

Parameters:

Name Type Description Default
pytree PyTree

The input PyTree to filter.

required
condition Callable[[Any], bool]

A function that evaluates to True for nodes that should be removed.

required
stop_if Callable[[Any], bool]

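Optional function that evaluates to True for nodes that should be treated as leaves (traversal does not descend into them), in addition to nodes matching condition. Such nodes are kept unchanged unless they also match condition.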

None

Returns:

Name Type Description
Any Any

A copy of the PyTree with the matched nodes replaced by None.

Source code in parax/filters.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def remove(
    pytree: PyTree,
    condition: Callable[[Any], bool],
    *,
    stop_if: "Callable[[Any], bool] | None" = None,
) -> Any:
    """Removes nodes from a PyTree that match a given condition.

    Replaces matching nodes with None. Halts traversal at matching nodes,
    as well as any nodes matching `stop_if`.

    Args:
        pytree (PyTree): The input PyTree to filter.
        condition (Callable[[Any], bool]): A function that evaluates to True
            for nodes that should be removed (replaced by None).
        stop_if (Callable[[Any], bool] | None): Optional function that
            evaluates to True for nodes that should be treated as leaves,
            i.e. traversal does not descend into them. Such nodes are kept
            unchanged unless they also match `condition`. Defaults to None,
            in which case only `condition` halts traversal.

    Returns:
        Any: A copy of the PyTree with the matched nodes replaced by None.
    """

    def _stops_traversal(x: Any) -> bool:
        # Treat a node as a leaf if either predicate matches; `stop_if` is
        # consulted first, and skipped entirely when it was not supplied.
        return (stop_if is not None and stop_if(x)) or condition(x)

    # `condition` serves twice: via `is_leaf` it decides where traversal stops,
    # and as `filter_spec` it decides which of those leaves are dropped. With
    # `inverse=True`, leaves for which `condition` is True become None.
    return eqx.filter(
        pytree,
        filter_spec=condition,
        is_leaf=_stops_traversal,
        inverse=True,
    )

parax.is_constant(x)

Returns True if x is an instance of parax.AbstractConstant.

Useful as is_leaf when partitioning a model to freeze standard parameters.

Source code in parax/constants.py
35
36
37
38
39
40
41
def is_constant(x: Any) -> TypeGuard[AbstractConstant]:
    """Check whether `x` is a `parax.AbstractConstant`.

    Handy as the `is_leaf` predicate when partitioning a model to freeze
    standard parameters.

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractConstant` instance, otherwise False.
    """
    matches_constant = isinstance(x, AbstractConstant)
    return matches_constant

parax.is_annotated(x)

Returns True if x is an instance of parax.AbstractAnnotated (i.e. has metadata).

Source code in parax/annotation.py
28
29
30
31
32
33
def is_annotated(x: Any) -> TypeGuard[AbstractAnnotated]:
    """Check whether `x` is a `parax.AbstractAnnotated` (i.e. carries metadata).

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractAnnotated` instance, otherwise False.
    """
    has_metadata = isinstance(x, AbstractAnnotated)
    return has_metadata

parax.is_variable(x)

Returns True if x is an instance of parax.AbstractVariable.

Source code in parax/variables.py
120
121
122
123
124
def is_variable(x: Any) -> TypeGuard[AbstractVariable]:
    """Check whether `x` is a `parax.AbstractVariable`.

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractVariable` instance, otherwise False.
    """
    matches_variable = isinstance(x, AbstractVariable)
    return matches_variable

parax.is_param(x)

Returns True if x is an instance of parax.AbstractVariable or if eqx.is_inexact_array(x) is True.

Source code in parax/variables.py
127
128
129
130
131
132
def is_param(x: Any) -> bool:
    """Check whether `x` should be treated as a parameter.

    A node counts as a parameter if it is a `parax.AbstractVariable`, or,
    failing that, if `eqx.is_inexact_array(x)` holds.

    Args:
        x: The object to test.

    Returns:
        True for `AbstractVariable` instances; otherwise the result of
        `eqx.is_inexact_array(x)`.
    """
    if isinstance(x, AbstractVariable):
        return True
    # Fall back to Equinox's notion of a trainable (inexact) array.
    return eqx.is_inexact_array(x)

parax.is_unwrappable(x)

Checks if a given object is an unwrappable node.

Parameters:

Name Type Description Default
x Any

The object to check.

required

Returns:

Type Description
TypeGuard[AbstractUnwrappable]

True if x is an instance of AbstractUnwrappable, False otherwise.

Source code in parax/wrappers.py
47
48
49
50
51
52
53
54
55
56
def is_unwrappable(x: Any) -> TypeGuard[AbstractUnwrappable]:
    """Determine whether an object is an unwrappable node.

    Args:
        x: Any object to inspect.

    Returns:
        True when `x` is an `AbstractUnwrappable` instance; False for
        anything else.
    """
    unwrappable = isinstance(x, AbstractUnwrappable)
    return unwrappable

parax.is_wrappable(x)

Checks if a given object is a wrappable node.

Parameters:

Name Type Description Default
x Any

The object to check.

required

Returns:

Type Description
TypeGuard[AbstractWrappable]

True if x is an instance of AbstractWrappable, False otherwise.

Source code in parax/wrappers.py
209
210
211
212
213
214
215
216
217
218
def is_wrappable(x: Any) -> TypeGuard[AbstractWrappable]:
    """Determine whether an object is a wrappable node.

    Args:
        x: Any object to inspect.

    Returns:
        True when `x` is an `AbstractWrappable` instance; False for
        anything else.
    """
    wrappable = isinstance(x, AbstractWrappable)
    return wrappable

parax.is_bounded(x)

Returns True if x is an instance of parax.AbstractBounded.

Source code in parax/bounds.py
27
28
29
30
31
def is_bounded(x: Any) -> TypeGuard[AbstractBounded]:
    """Check whether `x` is a `parax.AbstractBounded`.

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractBounded` instance, otherwise False.
    """
    bounded = isinstance(x, AbstractBounded)
    return bounded

parax.is_probabilistic(x)

Returns True if x is an instance of parax.AbstractProbabilistic.

Source code in parax/probability.py
38
39
40
41
42
def is_probabilistic(x: Any) -> TypeGuard[AbstractProbabilistic]:
    """Check whether `x` is a `parax.AbstractProbabilistic`.

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractProbabilistic` instance, otherwise False.
    """
    probabilistic = isinstance(x, AbstractProbabilistic)
    return probabilistic

parax.is_constraint(x)

Returns True if x is an instance of parax.AbstractConstraint.

Source code in parax/constraints.py
93
94
95
96
97
def is_constraint(x: Any) -> TypeGuard[AbstractConstraint]:
    """Check whether `x` is a `parax.AbstractConstraint`.

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractConstraint` instance, otherwise False.
    """
    constrained = isinstance(x, AbstractConstraint)
    return constrained

parax.is_distribution(x)

Returns True if x is an instance of distreqx.AbstractDistribution.

Source code in parax/filters.py
14
15
16
17
18
def is_distribution(x: Any) -> TypeGuard[AbstractDistribution]:
    """Check whether `x` is a `distreqx.AbstractDistribution`.

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractDistribution` instance, otherwise False.
    """
    distribution = isinstance(x, AbstractDistribution)
    return distribution

parax.is_bijector(x)

Returns True if x is an instance of distreqx.AbstractBijector.

Source code in parax/filters.py
21
22
23
24
25
def is_bijector(x: Any) -> TypeGuard[AbstractBijector]:
    """Check whether `x` is a `distreqx.AbstractBijector`.

    Args:
        x: The object to test.

    Returns:
        True when `x` is an `AbstractBijector` instance, otherwise False.
    """
    bijector = isinstance(x, AbstractBijector)
    return bijector