Overview

Parax was originally designed as a lower-level library on which other frameworks can be built. For example, ParamRF uses Parax as its backend. In this quick demonstration, we create a custom Parax variable that wraps another variable while adding a name, a scale and metadata, to show how Parax can be used for this purpose.

1. Defining the Param class

Parax makes use of composable, nested wrappers for variable definitions. While we could use parax.Tagged and parax.Transformed variables to implement a name and scale in our framework, these wrappers are general-purpose. Instead, this is an ideal use case for implementing a custom parax.AbstractVariable class, which gives the users of our framework domain-specific properties and added type safety while inheriting all of Parax's unwrapping and metadata features.

Below we create a Param class that wraps another arbitrary variable, so that any of Parax's built-in variables (e.g. parax.Constrained, parax.Random etc.) can be nested inside it, while providing easy access to .name and .scale properties:

from typing import Any, Self

import equinox as eqx
import jax
import jax.numpy as jnp
import parax as prx
from parax.annotation import AbstractAnnotated

class Param(prx.AbstractVariable, prx.AbstractWrappable[jax.Array], AbstractAnnotated[Any]):
    # The wrapped variable; any Parax variable can be nested here.
    raw_value: prx.AbstractVariable
    # Static fields hold metadata rather than array data.
    scale: float = eqx.field(default=1.0, static=True)
    name: str | None = eqx.field(default=None, kw_only=True, static=True)
    metadata: Any = eqx.field(default=None, kw_only=True, static=True)

    @property
    def value(self) -> jax.Array:
        base_value = jnp.array(self.raw_value)
        if self.scale != 1.0:
            return base_value * self.scale
        return base_value

    def wrap(self, value: jax.Array) -> Self:
        new_raw_value = self.raw_value.wrap(value / self.scale)
        return eqx.tree_at(lambda x: x.raw_value, self, new_raw_value)

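Note that we use Equinox's static=True to mark scale, name and metadata as static fields, so JAX treats them as part of the pytree structure rather than as array leaves.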
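The value property and the wrap method above are meant to be inverses of each other: value multiplies the stored raw value by the scale, and wrap divides by it before storing. The following is a minimal sketch of that round trip using only the standard library. ScaledBox is a hypothetical stand-in for Param and does not use Parax or Equinox at all.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ScaledBox:
    """Hypothetical stand-in for Param: a raw number plus a fixed scale."""

    raw_value: float
    scale: float = 1.0

    @property
    def value(self) -> float:
        # Forward direction: raw storage -> user-facing value.
        return self.raw_value * self.scale

    def wrap(self, value: float) -> "ScaledBox":
        # Inverse direction: divide by the scale so that .value returns `value`.
        # replace() builds a new instance, mirroring the immutable update
        # that eqx.tree_at performs in the real class.
        return replace(self, raw_value=value / self.scale)


box = ScaledBox(raw_value=2.0, scale=1e-3)
updated = box.wrap(0.005)

print(box.value)          # user-facing value of the original box
print(updated.raw_value)  # raw storage after wrapping 0.005
print(updated.value)      # round trip: approximately 0.005 again
```

The original box is left unchanged, and the scale is carried over to the new instance, which is the same behaviour the Param.wrap method aims for.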

2. Creating a parameter

Now we can easily create a parameter instance:

from parax.constraints import Positive

my_param = Param(prx.Constrained(constraint=Positive(), value=2.0), scale=1e-3, name='my_param')

my_param.value
# Array(0.002, dtype=float32)

my_param.raw_value.raw_value
# Array(1.8545866, dtype=float32)

my_param.name
# 'my_param'
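The three printed numbers are related by two transformations: the constraint maps the innermost raw value to 2.0, and the scale then maps 2.0 to 0.002. The transform that Positive uses is not stated here, but the printed raw value is consistent with an inverse softplus, so the sketch below assumes that parameterization. It is plain arithmetic with the standard library and calls nothing from Parax.

```python
import math


def softplus(x: float) -> float:
    # Maps any real number to a strictly positive one: log(1 + exp(x)).
    return math.log1p(math.exp(x))


def inverse_softplus(y: float) -> float:
    # Inverse of softplus for y > 0: log(exp(y) - 1).
    return math.log(math.expm1(y))


constrained_value = 2.0  # the value passed to prx.Constrained above
scale = 1e-3             # the scale passed to Param above

# Assumption: the unconstrained raw value is the inverse softplus of 2.0.
raw = inverse_softplus(constrained_value)
print(round(raw, 6))  # 1.854587

# Going forward again: constrain, then apply the scale.
print(softplus(raw) * scale)  # approximately 0.002
```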

We can then perform tree mapping just as usual. For example, we could extract the bounds:

bounds = prx.bounds.tree_bounds(my_param)

bounds[0].raw_value
# Array(0., dtype=float32)

bounds[1].raw_value
# Array(inf, dtype=float32)

Refer to the other examples in the documentation for more specific applications.