
Converters

parax.as_free(value)

Returns a freed version of value by stripping any constant wrappers.

If value implements AbstractConstant, this calls value.as_free(). Otherwise, it acts as a safe no-op and returns value unchanged. This makes it safe to use directly within a jax.tree_util.tree_map over mixed PyTrees.

Parameters:

Name   Type                            Description                                                       Default
value  Union[AbstractConstant[T], T]   An arbitrary value, potentially wrapped in an AbstractConstant.   required

Returns:

Type   Description
T      The freed parameter, or the original value if it was not wrapped in an AbstractConstant.

Source code in parax/converters.py
def as_free(value: Union[AbstractConstant[T], T]) -> T:
    """
    Returns a freed version of `value` by stripping any constant wrappers.

    If `value` implements `AbstractConstant`, this calls `value.as_free()`.
    Otherwise, it acts as a safe no-op and returns `value` unchanged. This
    makes it safe to use directly within a `jax.tree_util.tree_map` over mixed PyTrees.

    Args:
        value: An arbitrary value, potentially wrapped in an `AbstractConstant`.

    Returns:
        The freed parameter, or the original value if it was not wrapped in an
        `AbstractConstant`.
    """
    if isinstance(value, AbstractConstant):
        return value.as_free()
    return value
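To illustrate the "safe no-op" behaviour over a mixed PyTree, here is a minimal sketch. It does not import parax; `Constant` is a hypothetical stand-in for an `AbstractConstant` implementation, and `as_free` is re-declared with the same control flow as the source above. The `is_leaf` argument is an assumption about usage: it is only required if the real wrapper type is itself registered as a PyTree node (for example an Equinox module), in which case `tree_map` would otherwise descend into it and never pass the wrapper to `as_free`.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for parax.AbstractConstant: wraps a value and can
# hand back the bare value through .as_free().
class Constant:
    def __init__(self, value):
        self.value = value

    def as_free(self):
        return self.value


# Same logic as parax.as_free: unwrap constants, pass everything else through.
def as_free(value):
    if isinstance(value, Constant):
        return value.as_free()
    return value


params = {
    "w": Constant(jnp.ones(3)),  # constant-wrapped array
    "b": jnp.zeros(3),           # plain array, left untouched
    "meta": [Constant(2.0), 5],  # mixed Python values
}

# is_leaf makes tree_map hand whole Constant objects to as_free. It is
# redundant here (Constant is not a registered PyTree node) but needed if
# the wrapper type is itself a PyTree node.
freed = jax.tree_util.tree_map(
    as_free, params, is_leaf=lambda x: isinstance(x, Constant)
)
```

After the call, `freed["meta"]` is `[2.0, 5]` and `freed["w"]` is the bare array, while `freed["b"]` passes through unchanged.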

parax.as_fixed(value)

Returns value as a parax.Fixed variable, wrapping it if necessary.

Parameters:

Name   Type    Description                                Default
value  Param   An arbitrary variable or array-like object.   required

Returns:

Type    Description
Fixed   A fixed version of the variable.

Source code in parax/converters.py
def as_fixed(value: Param) -> Fixed:
    """
    Returns `value` as a `parax.Fixed` variable, wrapping it if necessary.

    Args:
        value: An arbitrary variable or array-like object.

    Returns:
        A fixed version of the variable.
    """
    if isinstance(value, Fixed):
        return value
    return Fixed(value)
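The `isinstance` check makes the conversion idempotent: calling it on something that is already `Fixed` returns the same object rather than nesting wrappers. A minimal sketch of that property follows; `Fixed` here is a hypothetical dataclass stand-in, not the parax class, and only the wrapping logic of `as_fixed` is reproduced.

```python
from dataclasses import dataclass
from typing import Any

import jax.numpy as jnp


# Hypothetical stand-in for parax.Fixed: a thin wrapper marking a value as fixed.
@dataclass(frozen=True)
class Fixed:
    value: Any


def as_fixed(value):
    # Same logic as parax.as_fixed: reuse an existing wrapper, otherwise wrap.
    if isinstance(value, Fixed):
        return value
    return Fixed(value)


x = jnp.arange(3.0)
once = as_fixed(x)      # wraps the raw array
twice = as_fixed(once)  # already Fixed: the same object comes back
```

Because `twice is once`, code can call the converter defensively without checking whether a value was already marked as fixed.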

parax.as_frozen(pytree)

Returns pytree wrapped in a parax.Frozen module, creating one if needed.

Parameters:

Name    Type                  Description            Default
pytree  Union[T, Frozen[T]]   An arbitrary PyTree.   required

Returns:

Type        Description
Frozen[T]   A frozen version of the PyTree. If it is already frozen, returns it directly.

Source code in parax/converters.py
def as_frozen(pytree: Union[T, Frozen[T]]) -> Frozen[T]:
    """
    Returns `pytree` wrapped in a `parax.Frozen` module, creating one if needed.

    Args:
        pytree: An arbitrary PyTree.

    Returns:
        A frozen version of the PyTree. If it is already frozen, returns it directly.
    """
    if isinstance(pytree, Frozen):
        return pytree
    return Frozen(pytree)
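Since the function always returns a `Frozen` wrapper (either the existing one or a new one), its return type is naturally `Frozen[T]`. The sketch below reproduces the wrapping logic with a hypothetical generic `Frozen` container (not the parax class) to show that a whole PyTree is wrapped as one unit and that an already-frozen tree is not wrapped twice.

```python
from typing import Generic, TypeVar, Union

import jax.numpy as jnp

T = TypeVar("T")


# Hypothetical stand-in for parax.Frozen: a generic container holding an
# entire PyTree. Only the wrapping behaviour of as_frozen is modelled.
class Frozen(Generic[T]):
    def __init__(self, tree: T):
        self.tree = tree


def as_frozen(pytree: Union[T, Frozen[T]]) -> Frozen[T]:
    # Already frozen: return the same object instead of nesting wrappers.
    if isinstance(pytree, Frozen):
        return pytree
    return Frozen(pytree)


encoder = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
frozen = as_frozen(encoder)  # the whole dict is wrapped once
again = as_frozen(frozen)    # no Frozen(Frozen(...)) nesting
```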

parax.as_param(value)

Returns value as a parax.Param. If is_param(value) is already true, value is returned unchanged; otherwise it is converted with jnp.asarray.

Parameters:

Name   Type   Description                    Default
value  Any    An arbitrary value or array.   required

Returns:

Type   Description
Any    The value unchanged if it is already a parameter, otherwise jnp.asarray(value).

Source code in parax/converters.py
def as_param(value: Any) -> Any:
    """
    Returns `value` as a `parax.Param`, converting it with `jnp.asarray` if necessary.

    Args:
        value: An arbitrary value or array.

    Returns:
        The value unchanged if `is_param(value)` is true, otherwise
        `jnp.asarray(value)`.
    """
    if is_param(value):
        return value
    return jnp.asarray(value)
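The definition of `is_param` is not shown on this page, so the sketch below substitutes a hypothetical predicate that treats any existing `jax.Array` as a parameter. Under that assumption it shows the two branches of `as_param`: array-likes such as Python lists are converted with `jnp.asarray`, while values that already count as parameters are returned as the very same object.

```python
import jax
import jax.numpy as jnp


# Hypothetical predicate: the real parax.is_param is not shown in this
# excerpt; here anything that is already a jax.Array counts as a parameter.
def is_param(value):
    return isinstance(value, jax.Array)


def as_param(value):
    # Same control flow as parax.as_param: pass parameters through,
    # convert anything else with jnp.asarray.
    if is_param(value):
        return value
    return jnp.asarray(value)


a = as_param([1.0, 2.0, 3.0])  # Python list -> jax.Array of shape (3,)
b = as_param(a)                # already an array: returned unchanged
```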