Converters
parax.as_free(value)
Returns a freed version of value by stripping any constant wrappers.
If value implements AbstractConstant, this calls value.as_free().
Otherwise, it is a no-op and returns value unchanged, so it can be
mapped directly over mixed PyTrees with jax.tree_util.tree_map (jax.tree_map in older JAX releases).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Union[AbstractConstant[T], T] | An arbitrary value, potentially wrapped in an | required |
Returns:
| Type | Description |
|---|---|
| T | The freed parameter, or the original value if it was not fixed. |
Source code in parax/converters.py
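The dispatch described above can be sketched in plain Python. This is a minimal illustration, not the code in parax/converters.py: the `AbstractConstant` and `Fixed` classes below are hypothetical stand-ins that mimic only the `as_free()` method, and a dict comprehension stands in for a tree map.

```python
from dataclasses import dataclass
from typing import Any


class AbstractConstant:
    """Hypothetical stand-in for parax.AbstractConstant (not the real class)."""

    def as_free(self) -> Any:
        raise NotImplementedError


@dataclass
class Fixed(AbstractConstant):
    """Hypothetical constant wrapper holding a single value."""

    value: Any

    def as_free(self) -> Any:
        return self.value


def as_free(value: Any) -> Any:
    """Unwrap constants via their as_free(); return anything else unchanged."""
    if isinstance(value, AbstractConstant):
        return value.as_free()
    return value


# A flat "mixed" container: one leaf is wrapped, the others are not.
mixed = {"lr": Fixed(0.1), "weights": [1.0, 2.0], "bias": 0.5}
freed = {name: as_free(leaf) for name, leaf in mixed.items()}
```

Because unwrapped leaves pass through untouched, the caller does not need to check each leaf's type before mapping.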
parax.as_fixed(value)
Returns value as a parax.Fixed variable, wrapping it if necessary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Param | An arbitrary variable or array-like object. | required |
Returns:
| Type | Description |
|---|---|
| Fixed | A fixed version of the variable. |
Source code in parax/converters.py
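One way to read "wrapping it if necessary" is that the conversion is idempotent: an input that is already fixed is passed through, and everything else is wrapped once. The sketch below assumes that reading. The `Fixed` class is a hypothetical stand-in, and its `as_free()` method reflects an assumption that the real `Fixed` behaves as a constant wrapper.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Fixed:
    """Hypothetical stand-in for parax.Fixed (not the real class)."""

    value: Any

    def as_free(self) -> Any:
        # Assumed behaviour: hand back the wrapped value.
        return self.value


def as_fixed(value: Any) -> Fixed:
    """Wrap value in Fixed unless it is already a Fixed instance."""
    if isinstance(value, Fixed):
        return value
    return Fixed(value)


# Mixed inputs: a raw scalar, an already-fixed value, and a raw list.
inputs = [1.5, Fixed(2.5), [3.0, 4.0]]
fixed = [as_fixed(x) for x in inputs]

# Under the assumption above, freeing each result recovers the raw values.
unwrapped = [f.as_free() for f in fixed]
```

With this behaviour, calling the converter twice never produces a doubly wrapped value.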
parax.as_frozen(pytree)
Returns pytree wrapped in a parax.Frozen module, creating one if needed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pytree | Union[T \| Frozen[T]] | An arbitrary PyTree. | required |
Returns:
| Type | Description |
|---|---|
| T | A frozen version of the PyTree. If it is already frozen, returns it directly. |
Source code in parax/converters.py
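The "returns it directly" behaviour can be illustrated with an identity check. This is a sketch under stated assumptions: `Frozen` here is a hypothetical stand-in that simply holds a reference to the tree, which may differ from how parax.Frozen stores its contents.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Frozen:
    """Hypothetical stand-in for parax.Frozen: an opaque wrapper around a PyTree."""

    tree: Any


def as_frozen(pytree: Any) -> Frozen:
    """Wrap pytree in Frozen, or return it directly if it is already Frozen."""
    if isinstance(pytree, Frozen):
        return pytree
    return Frozen(pytree)


params = {"w": [1.0, 2.0], "b": 0.0}
frozen_once = as_frozen(params)       # a new wrapper is created
frozen_twice = as_frozen(frozen_once)  # the same wrapper comes back
```

Returning the same object rather than re-wrapping keeps repeated calls cheap and avoids nested wrappers.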
parax.as_param(value)
Returns value as a parax.Param, wrapping it if necessary.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| value | Any | An arbitrary value or array. | required |
Returns:
| Type | Description |
|---|---|
| Any | The instantiated parameter. |
Source code in parax/converters.py