
Variables

parax.AbstractVariable

Bases: AbstractUnwrappable[Array]

The abstract interface for all model variables.

Derive from this class and override the value property to implement custom unwrapping behaviour.

All Parax variables, such as parax.Tagged, parax.Constrained, etc., derive from this class.

Corner Case Note (Math & Dunders): Because this class implements the __jax_array__ protocol and the standard math dunder methods, variables can be used directly in JAX expressions without explicitly calling unwrap(). However, any math operation (e.g., var + 1) immediately evaluates value and returns a plain jax.Array, so the result no longer carries the variable's metadata or constraint.
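A minimal pure-Python sketch of the behaviour described above, assuming plain floats in place of JAX arrays. `ToyVariable` and `Scaled` are hypothetical classes, not part of Parax, and the sketch implements only `__add__`/`__mul__` (plus their reflected forms) rather than `__jax_array__` and the full set of dunders.

```python
from abc import ABC, abstractmethod


class ToyVariable(ABC):
    """Hypothetical stand-in for parax.AbstractVariable (floats instead of arrays)."""

    @property
    @abstractmethod
    def value(self) -> float:
        """Subclasses compute the fully unwrapped value here."""

    # Math dunders evaluate `value` eagerly and return a plain number,
    # so the wrapper (and anything stored on it) is not carried over.
    def __add__(self, other):
        return self.value + other

    def __radd__(self, other):
        return other + self.value

    def __mul__(self, other):
        return self.value * other

    def __rmul__(self, other):
        return other * self.value


class Scaled(ToyVariable):
    """Hypothetical custom variable: overrides `value` to rescale its raw value."""

    def __init__(self, raw_value: float, scale: float):
        self.raw_value = raw_value
        self.scale = scale

    @property
    def value(self) -> float:
        return self.raw_value * self.scale


v = Scaled(raw_value=2.0, scale=3.0)
result = v + 1
print(type(result).__name__, result)  # float 7.0
```

Because `__add__` returns `self.value + other`, `result` is a bare float with no `scale` or `raw_value` attached, which mirrors how `var + 1` yields a plain `jax.Array` in Parax.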

value abstractmethod property

Returns the underlying, fully computed value of the variable.

parax.Param = AbstractVariable | Inexact[Array, '...'] module-attribute

A type alias representing a JAX parameter.

This includes any Parax variables (like Tagged, Constrained, Derived) as well as standard JAX inexact arrays.
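Since Param admits either kind of input, code consuming a Param has to resolve both. The sketch below is illustrative only: `HasValue`, `ToyParam`, `resolve`, and `Doubled` are hypothetical names, floats stand in for inexact JAX arrays, and `resolve` is not Parax's `unwrap()`.

```python
from typing import Protocol, Union, runtime_checkable


@runtime_checkable
class HasValue(Protocol):
    """Hypothetical structural stand-in for parax.AbstractVariable: anything with `.value`."""

    @property
    def value(self) -> float: ...


# Hypothetical analogue of parax.Param: a wrapped variable or a bare number.
ToyParam = Union[HasValue, float]


def resolve(p: ToyParam) -> float:
    """Return the concrete value of either kind of parameter."""
    return p.value if isinstance(p, HasValue) else p


class Doubled:
    """Hypothetical variable whose value is twice its raw value."""

    def __init__(self, raw_value: float):
        self.raw_value = raw_value

    @property
    def value(self) -> float:
        return 2 * self.raw_value


print(resolve(3.0), resolve(Doubled(3.0)))  # 3.0 6.0
```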

parax.Tagged

Bases: AbstractVariable, AbstractAnnotated[dict]

A variable with dictionary metadata.

Represents a simple, trainable variable with a single underlying raw_value and metadata.

Attributes:

Name       Type   Description
raw_value  Param  The raw value used by optimizers and samplers.
metadata   dict   Additional arbitrary metadata.
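A rough sketch of the Tagged data layout using a stdlib dataclass. `ToyTagged` is hypothetical, and it assumes the value of a tagged variable is simply its raw value; the real class also participates in Parax's unwrapping and annotation interfaces, which are not modelled here.

```python
from dataclasses import dataclass, field


@dataclass
class ToyTagged:
    """Hypothetical stand-in for parax.Tagged: one raw value plus a metadata dict."""

    raw_value: float
    metadata: dict = field(default_factory=dict)

    @property
    def value(self) -> float:
        # Assumption: a tagged variable's value is just its raw value.
        return self.raw_value


lr = ToyTagged(raw_value=0.1, metadata={"group": "optimizer", "units": "1/step"})
print(lr.value, lr.metadata["group"])  # 0.1 optimizer
```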

parax.Fixed

Bases: AbstractVariable, AbstractConstant[AbstractVariable]

A fixed variable, held constant rather than trained.

Implements AbstractConstant so that structural filtering can separate fixed variables from trainable ones during partitioning.

Corner Case Note: This class implements __getattr__ to forward all unrecognized attribute lookups to the underlying wrapped variable. This means a Fixed(Constrained(...)) object will still safely expose .constraint, .bounds, and .metadata to the user as if it weren't wrapped at all.

Attributes:

Name       Type   Description
raw_value  Param  The underlying variable that is being fixed.
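The attribute forwarding described in the corner-case note can be sketched with Python's `__getattr__` hook, which is only consulted when normal attribute lookup fails. `ToyFixed` and `ToyConstrained` are hypothetical classes that use a string for the constraint; they are not the Parax implementations.

```python
class ToyConstrained:
    """Hypothetical inner variable exposing a `constraint` attribute."""

    def __init__(self, raw_value: float, constraint: str):
        self.raw_value = raw_value
        self.constraint = constraint

    @property
    def value(self) -> float:
        return self.raw_value


class ToyFixed:
    """Hypothetical stand-in for parax.Fixed: wraps a variable and forwards lookups."""

    def __init__(self, raw_value):
        self.raw_value = raw_value  # the wrapped variable

    @property
    def value(self):
        return self.raw_value.value

    def __getattr__(self, name):
        # Only reached when normal lookup fails, so `raw_value` and `value`
        # resolve on ToyFixed itself; everything else goes to the wrapped variable.
        if name == "raw_value":
            # Guard against infinite recursion if raw_value was never set.
            raise AttributeError(name)
        return getattr(self.raw_value, name)


fixed = ToyFixed(ToyConstrained(raw_value=1.5, constraint="positive"))
print(fixed.constraint, fixed.value)  # positive 1.5
```

The `raw_value` guard matters for instances created without running `__init__`: without it, looking up the missing attribute would call `__getattr__` on itself forever.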

parax.Derived

Bases: AbstractVariable

A derived variable.

The parameter's value is dynamically derived via an arbitrary callable.

This is ideal for one-way transformations, projections, or normalizations where a strict bijector (with an inverse) is not required or mathematically possible (e.g., applying jax.nn.softmax to raw logits).

Attributes:

Name       Type      Description
raw_value  Param     The raw value used by optimizers and samplers.
fn         Callable  The callable used to transform the raw value.

value property

The derived value.

Returns raw_value transformed by fn.
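A small sketch of the softmax use case from the description, with Python lists in place of JAX arrays and a hand-written `softmax` in place of `jax.nn.softmax`. `ToyDerived` is a hypothetical stand-in that recomputes `fn(raw_value)` on every access, which is an assumption about how the value property behaves.

```python
import math
from typing import Callable, Sequence


def softmax(logits: Sequence[float]) -> list[float]:
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


class ToyDerived:
    """Hypothetical stand-in for parax.Derived: value = fn(raw_value), recomputed on access."""

    def __init__(self, raw_value, fn: Callable):
        self.raw_value = raw_value
        self.fn = fn

    @property
    def value(self):
        return self.fn(self.raw_value)


weights = ToyDerived(raw_value=[0.0, 0.0, math.log(2.0)], fn=softmax)
print(weights.value)  # approximately [0.25, 0.25, 0.5]
```

Softmax is unchanged when the same constant is added to every logit, so distinct raw values give the same output and no inverse exists; this is why a one-way Derived variable fits better here than a bijective constraint.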

parax.Constrained(constraint=None, value=None, *, raw_value=None)

Bases: AbstractVariable, AbstractBounded[Array]

A constrained variable.

The constraint is specified via a parax.AbstractConstraint.

When the variable is evaluated, the constraint's bijector maps raw_value from unconstrained space to the constrained value. This class also implements the parax.bounded.AbstractBounded interface for integration with bounded optimizers.

Attributes:

Name        Type                Description
raw_value   Param               The raw value in unconstrained (real-line) space; this is the value optimizers act on.
constraint  AbstractConstraint  The parameter constraint defining the bounds and the bijector mapping.

Parameters:

Name        Type                       Default  Description
constraint  AbstractConstraint | None  None     A Parax constraint. If None, defaults to parax.RealLine (unconstrained).
value       Array | None               None     The desired output (constrained) value. If provided, raw_value is computed once, at construction, via the constraint's inverse bijector. Mutually exclusive with raw_value.
raw_value   Param | None               None     The unconstrained optimizer-space value. Mutually exclusive with value.

Exactly one of value and raw_value must be provided; otherwise a ValueError is raised.
Source code in parax/variables.py
def __init__(
    self,
    constraint: AbstractConstraint | None = None,
    value: Array | None = None,
    *,
    raw_value: Param | None = None,
):
    """
    Args:
        constraint: A Parax constraint. If None, defaults to `parax.RealLine` (unconstrained).
        value: The desired output (constrained) value. If provided, the internal
            `raw_value` is computed once, at construction, via the constraint's
            inverse bijector. Mutually exclusive with `raw_value`.
        raw_value: The unconstrained optimizer-space value. Mutually exclusive with `value`.
    """
    # Error checking
    if value is None and raw_value is None:
        raise ValueError("Must provide either `value` or `raw_value`.")
    if value is not None and raw_value is not None:
        raise ValueError("Cannot provide both `value` and `raw_value`.")

    # Array standardization
    if raw_value is not None:
        raw_value = _as_param(raw_value)
        shape = raw_value.shape
    else:
        value = jnp.asarray(value)
        shape = value.shape

    # Default to an unconstrained RealLine matching the input's shape
    if constraint is None:
        constraint = RealLine(shape=shape)

    # Map a constrained `value` into raw (unconstrained) space via the inverse bijector
    if value is not None:
        raw_value = constraint.bijector.inverse(value)

    self.constraint = constraint
    self.raw_value = raw_value
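The construction logic in the listing above can be exercised outside JAX with a toy constraint. This is a sketch under stated assumptions: `ToyBijector`, `ToyConstraint`, `REAL_LINE`, `POSITIVE`, and `ToyConstrained` are hypothetical, scalars replace arrays (so the shape handling and `_as_param` call are omitted), and the `value` property assumes the forward bijector is applied on each read, as the class description suggests.

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class ToyBijector:
    """Hypothetical bijector: forward maps raw -> constrained, inverse maps back."""
    forward: Callable[[float], float]
    inverse: Callable[[float], float]


@dataclass(frozen=True)
class ToyConstraint:
    """Hypothetical stand-in for a Parax constraint carrying a bijector."""
    bijector: ToyBijector


REAL_LINE = ToyConstraint(ToyBijector(forward=lambda r: r, inverse=lambda v: v))
POSITIVE = ToyConstraint(ToyBijector(forward=math.exp, inverse=math.log))


class ToyConstrained:
    """Hypothetical scalar analogue of parax.Constrained."""

    def __init__(self, constraint: Optional[ToyConstraint] = None,
                 value: Optional[float] = None, *,
                 raw_value: Optional[float] = None):
        # Exactly one of `value` / `raw_value` must be given.
        if value is None and raw_value is None:
            raise ValueError("Must provide either `value` or `raw_value`.")
        if value is not None and raw_value is not None:
            raise ValueError("Cannot provide both `value` and `raw_value`.")
        # Default to the unconstrained (identity) constraint.
        if constraint is None:
            constraint = REAL_LINE
        # A constrained `value` is mapped into raw space once, at construction.
        if value is not None:
            raw_value = constraint.bijector.inverse(value)
        self.constraint = constraint
        self.raw_value = raw_value

    @property
    def value(self) -> float:
        # Assumption: the forward bijector is applied every time the value is read.
        return self.constraint.bijector.forward(self.raw_value)


scale = ToyConstrained(POSITIVE, value=2.0)
print(scale.raw_value)  # 0.6931471805599453, i.e. log(2.0)
```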

parax.tagged(default=dataclasses.MISSING, metadata=None)

Specifies a dataclass field for a Parax Tagged variable.

Parameters:

Name      Type         Default  Description
default   Param        MISSING  The default value. If omitted, the field must be supplied at instantiation.
metadata  dict | None  None     Additional static metadata, stored on every Tagged variable this field creates.

Returns:

Type  Description
Any   An equinox.field whose converter wraps plain inputs in parax.Tagged and passes existing Parax variables through unchanged.

Source code in parax/variables.py
def tagged(
    default: Param = dataclasses.MISSING,
    metadata: dict | None = None,
) -> Any:
    """
    Specifies a dataclass field for a Parax `Tagged` variable.

    Args:
        default: The default value. If omitted, this field becomes required 
            by the user during instantiation.
        metadata: Additional static metadata to store.

    Returns:
        An `equinox.field` properly configured for the field type.
    """
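    if metadata is None:
        metadata = {}

    def converter(x: Any) -> AbstractVariable:
        # Values that are already Parax variables pass through unchanged.
        if isinstance(x, AbstractVariable):
            return x
        # Otherwise, wrap the input in a Tagged variable carrying this field's metadata.
        return Tagged(raw_value=x, metadata=metadata)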

    field_kwargs = {"converter": converter}
    if default is not dataclasses.MISSING:
        field_kwargs["default"] = default

    return eqx.field(**field_kwargs)
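Standard-library dataclasses have no `converter=` argument, so the sketch below emulates the `tagged` pattern by storing the converter in the field metadata and applying it in `__post_init__`; Equinox applies `converter` itself during initialisation. `ToyTagged`, `tagged_field`, `ConvertingBase`, and `Model` are hypothetical names for illustration only.

```python
import dataclasses
from typing import Any, Optional


@dataclasses.dataclass
class ToyTagged:
    """Hypothetical stand-in for parax.Tagged."""
    raw_value: Any
    metadata: dict


def tagged_field(default: Any = dataclasses.MISSING, metadata: Optional[dict] = None) -> Any:
    """Hypothetical analogue of parax.tagged built on stdlib dataclasses."""
    if metadata is None:
        metadata = {}

    def converter(x: Any) -> Any:
        # Values that are already variables pass through unchanged.
        if isinstance(x, ToyTagged):
            return x
        return ToyTagged(raw_value=x, metadata=metadata)

    # Stash the converter where ConvertingBase.__post_init__ can find it.
    return dataclasses.field(default=default, metadata={"converter": converter})


class ConvertingBase:
    """Applies each field's stored converter after the generated __init__ runs."""

    def __post_init__(self):
        for f in dataclasses.fields(self):
            conv = f.metadata.get("converter")
            if conv is not None:
                setattr(self, f.name, conv(getattr(self, f.name)))


@dataclasses.dataclass
class Model(ConvertingBase):
    weight: Any = tagged_field()  # no default, so this field is required
    bias: Any = tagged_field(default=0.0, metadata={"role": "offset"})


m = Model(weight=1.5)
print(m.bias)  # ToyTagged(raw_value=0.0, metadata={'role': 'offset'})
```

As in the Parax listing, every variable created by one field shares the same `metadata` dict object, so mutating it through one instance affects all of them.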

parax.derived(fn=lambda x: x, default=dataclasses.MISSING)

Specifies a dataclass field for a Parax Derived variable.

Parameters:

Name     Type      Default      Description
fn       Callable  lambda x: x  The callable used to transform the raw value (the identity by default).
default  Param     MISSING      The default raw value. If omitted, the field is required.

Returns:

Type  Description
Any   An equinox.field whose converter wraps plain inputs in parax.Derived (using fn) and passes existing Parax variables through unchanged.

Source code in parax/variables.py
def derived(
    fn: Callable = lambda x: x,
    default: Param = dataclasses.MISSING,
) -> Any:
    """
    Specifies a dataclass field for a Parax `Derived` variable.

    Args:
        fn: The callable used to transform the raw value.
        default: The default raw value. If omitted, this field becomes required.

    Returns:
        An `equinox.field` properly configured for the field type.
    """
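    def converter(x: Any) -> AbstractVariable:
        # Values that are already Parax variables pass through unchanged.
        if isinstance(x, AbstractVariable):
            return x
        # Otherwise, the input becomes the raw value that `fn` transforms.
        return Derived(fn=fn, raw_value=x)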

    field_kwargs = {"converter": converter}
    if default is not dataclasses.MISSING:
        field_kwargs["default"] = default

    return eqx.field(**field_kwargs)

parax.constrained(constraint=None, default=dataclasses.MISSING)

Specifies a dataclass field for a parax.Constrained variable.

Parameters:

Name        Type                       Default  Description
constraint  AbstractConstraint | None  None     The constraint defining the bounds and bijector mapping. If None, parax.Constrained defaults to parax.RealLine.
default     Param                      MISSING  The default constrained (output-space) value. If omitted, the field is required.

Returns:

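Type  Description
Any   An equinox.field whose converter passes existing Parax variables through unchanged and wraps plain inputs in parax.Constrained as constrained values (value=x).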

Source code in parax/variables.py
def constrained(
    constraint: AbstractConstraint | None = None,
    default: Param = dataclasses.MISSING,
) -> Any:
    """
    Specifies a dataclass field for a Parax `Constrained` variable.

    Args:
        constraint: The constraint defining the bounds and bijector mapping.
            If None, `Constrained` defaults to `parax.RealLine`.
        default: The default constrained value. If omitted, this field becomes required.

    Returns:
        An `equinox.field` properly configured for the field type.
    """
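    def converter(x: Any) -> AbstractVariable:
        # Values that are already Parax variables pass through unchanged.
        if isinstance(x, AbstractVariable):
            return x
        # Otherwise, treat the input as a constrained (output-space) value.
        return Constrained(constraint=constraint, value=x)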

    field_kwargs = {"converter": converter}
    if default is not dataclasses.MISSING:
        field_kwargs["default"] = default

    return eqx.field(**field_kwargs)
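One detail of `constrained` worth illustrating is that `default` is interpreted as a constrained (output-space) value, because the converter calls `Constrained(constraint=..., value=x)`. The hypothetical sketch below hard-codes a positivity constraint with an `exp`/`log` bijector and applies the converter step in a dataclass `__post_init__`; `ToyConstrained` and `Kernel` are illustrative names, not Parax APIs.

```python
import dataclasses
import math


@dataclasses.dataclass
class ToyConstrained:
    """Hypothetical stand-in: stores raw_value; value = exp(raw_value) (positivity)."""
    raw_value: float

    @classmethod
    def from_value(cls, value: float) -> "ToyConstrained":
        # Mirrors Constrained(constraint=..., value=x): inverse-map into raw space.
        return cls(raw_value=math.log(value))

    @property
    def value(self) -> float:
        return math.exp(self.raw_value)


@dataclasses.dataclass
class Kernel:
    # Like parax.constrained(default=...), the default is a *constrained* value.
    lengthscale: object = 1.0

    def __post_init__(self):
        # Converter step: wrap plain inputs, pass existing variables through.
        if not isinstance(self.lengthscale, ToyConstrained):
            self.lengthscale = ToyConstrained.from_value(self.lengthscale)


k = Kernel()
print(k.lengthscale.raw_value)  # 0.0, since log(1.0) == 0.0
```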