Variables
parax.AbstractVariable
Bases: AbstractUnwrappable[Array]
The abstract interface for all model variables.
Derive from this class and override value to implement
custom variable unwrapping behaviour.
All Parax variables, such as parax.Tagged,
parax.Constrained, etc., derive from this class.
Corner Case Note (Math & Dunders): Because this class implements the
__jax_array__ protocol and all standard math dunder methods, variables
can be used directly in JAX expressions without explicitly calling
unwrap(). However, any math operation (e.g., var + 1) immediately
evaluates value and returns a plain jax.Array, discarding the metadata
and constraint wrappers.
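The pattern described above (subclasses override value; arithmetic evaluates value eagerly and returns an unwrapped result) can be sketched in plain Python. This is a minimal stand-in, not Parax's implementation: MiniVariable and Scaled are hypothetical names, floats stand in for jax.Array, and the __jax_array__ hook is omitted so the example runs without JAX.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class MiniVariable(ABC):
    """Hypothetical stand-in for an abstract variable (not parax.AbstractVariable)."""

    @property
    @abstractmethod
    def value(self):
        """Return the fully computed value."""

    # Arithmetic evaluates `value` first, so the result is a plain number
    # and the wrapper (with any metadata it carries) is discarded.
    def __add__(self, other):
        return self.value + other

    def __radd__(self, other):
        return other + self.value

    def __mul__(self, other):
        return self.value * other

    def __rmul__(self, other):
        return other * self.value


@dataclass
class Scaled(MiniVariable):
    """Hypothetical custom variable: the stored raw value times a fixed factor."""

    raw_value: float
    factor: float = 2.0

    @property
    def value(self):
        return self.raw_value * self.factor


v = Scaled(raw_value=1.5)
print(v.value)                        # 3.0
result = v + 1
print(type(result).__name__, result)  # float 4.0
```

Note that result is a bare float: once arithmetic is applied, nothing of the Scaled wrapper survives, which mirrors the corner case described for real Parax variables.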
value (abstractmethod, property)
Returns the underlying, fully computed value of the variable.
parax.Param = AbstractVariable | Inexact[Array, '...'] (module attribute)
A type alias representing a JAX parameter.
This includes any Parax variables (like Tagged, Constrained, Derived)
as well as standard JAX inexact arrays.
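Because a Param may be either a variable or a plain array, code that consumes parameters typically normalizes both cases. The sketch below is a hypothetical stdlib illustration of that idea: Wrapper, ParamLike and resolve are invented names, floats stand in for inexact arrays, and nested wrappers are resolved through their value property.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Wrapper:
    """Hypothetical stand-in for a Parax variable; raw_value may itself be a Wrapper."""

    raw_value: "ParamLike"

    @property
    def value(self) -> float:
        # Resolve the inner parameter first, so nested wrappers unwrap fully.
        return resolve(self.raw_value)


# Hypothetical analogue of `parax.Param`: either a variable or a plain number.
ParamLike = Union[Wrapper, float]


def resolve(p: ParamLike) -> float:
    """Return the computed value of a variable, or a plain value unchanged."""
    return p.value if isinstance(p, Wrapper) else p


print(resolve(0.5))                    # 0.5
print(resolve(Wrapper(Wrapper(0.5))))  # 0.5
```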
parax.Tagged
Bases: AbstractVariable, AbstractAnnotated[dict]
A variable with dictionary metadata.
Represents a simple, trainable variable
with a single underlying raw_value and metadata.
Attributes:
| Name | Type | Description |
|---|---|---|
| raw_value | Param | The raw value used by optimizers and samplers. |
| metadata | dict | Additional arbitrary metadata. |
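A tagged variable applies no transformation; it simply carries a metadata dict alongside its raw value. The following hypothetical stand-in (TaggedSketch is not parax.Tagged, and the metadata keys are invented) shows how such metadata can drive selection logic over a collection of parameters.

```python
from dataclasses import dataclass, field


@dataclass
class TaggedSketch:
    """Hypothetical stand-in for parax.Tagged: a raw value plus a metadata dict."""

    raw_value: float
    metadata: dict = field(default_factory=dict)

    @property
    def value(self):
        # No transformation: the value is the raw value itself.
        return self.raw_value


params = {
    "lr_scale": TaggedSketch(0.1, {"units": "1/s", "prior": "lognormal"}),
    "offset": TaggedSketch(-2.0, {"units": "m"}),
}

# Metadata travels with each variable, so it can be used to select parameters.
with_prior = sorted(k for k, p in params.items() if "prior" in p.metadata)
print(with_prior)               # ['lr_scale']
print(params["offset"].value)   # -2.0
```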
parax.Fixed
Bases: AbstractVariable, AbstractConstant[AbstractVariable]
A fixed variable.
Implements AbstractConstant for structural filtering during partitioning.
Corner Case Note: This class implements __getattr__ to forward all
unrecognized attribute lookups to the underlying wrapped variable. This means
a Fixed(Constrained(...)) object will still safely expose .constraint,
.bounds, and .metadata to the user as if it weren't wrapped at all.
Attributes:
| Name | Type | Description |
|---|---|---|
| raw_value | Param | The underlying variable that is being fixed. |
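The attribute forwarding described in the corner case note can be sketched with Python's __getattr__ hook, which is only called when normal attribute lookup fails. This is a hypothetical stand-in (Inner and FixedSketch are invented names), not the Parax source; the explicit guard for raw_value avoids infinite recursion if the attribute is read before __init__ has set it (for example, during copying or unpickling).

```python
from dataclasses import dataclass


@dataclass
class Inner:
    """Hypothetical wrapped variable exposing an extra attribute."""

    raw_value: float
    bounds: tuple = (0.0, 1.0)

    @property
    def value(self):
        return self.raw_value


@dataclass
class FixedSketch:
    """Hypothetical stand-in for parax.Fixed that forwards unknown attributes."""

    raw_value: Inner

    @property
    def value(self):
        return self.raw_value.value

    def __getattr__(self, name):
        # Reached only for attributes FixedSketch does not define itself.
        if name == "raw_value":
            # Not yet set (e.g. mid-copy): fail instead of recursing forever.
            raise AttributeError(name)
        return getattr(self.raw_value, name)


f = FixedSketch(Inner(0.3, bounds=(0.0, 0.5)))
print(f.value)   # 0.3
print(f.bounds)  # (0.0, 0.5), forwarded from Inner
```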
parax.Derived
Bases: AbstractVariable
A derived variable.
The parameter's value is dynamically derived via an arbitrary callable.
This is ideal for one-way transformations, projections, or normalizations
where a strict bijector (with an inverse) is not required or mathematically
possible (e.g., applying jax.nn.softmax to raw logits).
Attributes:
| Name | Type | Description |
|---|---|---|
| raw_value | Param | The raw value used by optimizers and samplers. |
| fn | Callable | The callable used to transform the raw value. |
value (property)
The derived value.
Returns the raw value transformed by fn.
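The softmax case mentioned above can be sketched without JAX. DerivedSketch is a hypothetical stand-in for parax.Derived (value = fn(raw_value)), and the pure-Python softmax stands in for jax.nn.softmax.

```python
import math
from dataclasses import dataclass
from typing import Callable, Sequence


def softmax(xs: Sequence[float]) -> list:
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


@dataclass
class DerivedSketch:
    """Hypothetical stand-in for parax.Derived: value = fn(raw_value)."""

    raw_value: Sequence[float]
    fn: Callable = lambda x: x

    @property
    def value(self):
        return self.fn(self.raw_value)


weights = DerivedSketch(raw_value=[0.0, 0.0, math.log(2.0)], fn=softmax)
print(weights.value)  # approximately [0.25, 0.25, 0.5]
```

Adding the same constant to every logit leaves the output unchanged, so softmax has no inverse. That is why a one-way derivation fits here while a bijective constraint does not.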
parax.Constrained(constraint=None, value=None, *, raw_value=None)
Bases: AbstractVariable, AbstractBounded[Array]
A constrained variable.
The constraint is specified via a parax.AbstractConstraint and is
applied automatically as a bijective mapping during evaluation.
Implements the parax.bounded.AbstractBounded interface for integration
with bounded optimizers.
Attributes:
| Name | Type | Description |
|---|---|---|
| raw_value | Param | The raw, unconstrained value, which lives on the real number line. |
| constraint | AbstractConstraint | The parameter constraint defining bounds and bijector mappings. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| constraint | AbstractConstraint \| None | A Parax constraint. If None, defaults to | None |
| value | Array \| None | The desired output (constrained) value. If provided, the internal | None |
| raw_value | Param \| None | The unconstrained optimizer-space value. Mutually exclusive with | None |
Source code in parax/variables.py
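A constrained variable stores an unconstrained raw value and maps it through a bijection on evaluation; constructing it from a desired constrained value requires the inverse mapping. The sketch below assumes a softplus positivity constraint chosen for illustration. Positive and ConstrainedSketch are hypothetical names, not Parax's constraint classes. Also by assumption, the sketch requires exactly one of value or raw_value, which may be stricter than the real class.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Positive:
    """Hypothetical constraint: maps the real line onto (0, inf) via softplus."""

    def forward(self, raw: float) -> float:
        # softplus(x) = log(1 + exp(x)); log1p keeps precision for small exp(x).
        return math.log1p(math.exp(raw))

    def inverse(self, value: float) -> float:
        # Inverse softplus: log(exp(y) - 1), using expm1 for precision.
        return math.log(math.expm1(value))


class ConstrainedSketch:
    """Hypothetical stand-in for parax.Constrained (not its real implementation)."""

    def __init__(self, constraint, value=None, *, raw_value=None):
        # Assumption: exactly one of value / raw_value must be given.
        if (value is None) == (raw_value is None):
            raise ValueError("pass exactly one of value or raw_value")
        self.constraint = constraint
        # Store the unconstrained (optimizer-space) value internally.
        self.raw_value = raw_value if raw_value is not None else constraint.inverse(value)

    @property
    def value(self) -> float:
        return self.constraint.forward(self.raw_value)


scale = ConstrainedSketch(Positive(), value=2.0)
print(scale.raw_value)  # approximately 1.8546 (inverse softplus of 2.0)
print(scale.value)      # approximately 2.0, recovered through the forward map
```

An optimizer would update raw_value freely over the reals, while value always stays positive.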
parax.tagged(default=dataclasses.MISSING, metadata=None)
Specifies a dataclass field for a Parax Tagged variable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| default | Param | The default value. If omitted, the field must be supplied at instantiation. | MISSING |
| metadata | dict \| None | Additional static metadata to store. | None |
Returns:
| Type | Description |
|---|---|
| Any | An |
Source code in parax/variables.py
parax.derived(fn=lambda x: x, default=dataclasses.MISSING)
Specifies a dataclass field for a Parax Derived variable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable | The callable used to transform the raw value. | lambda x: x |
| default | Param | The default raw value. If omitted, this field becomes required. | MISSING |
Returns:
| Type | Description |
|---|---|
| Any | An |
Source code in parax/variables.py
parax.constrained(constraint=None, default=dataclasses.MISSING)
Specifies a dataclass field for a parax.Constrained variable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| constraint | AbstractConstraint \| None | The abstract constraint defining base bounds and mappings. | None |
| default | Param | The default constrained value. If omitted, this field becomes required. | MISSING |
Returns:
| Type | Description |
|---|---|
| Any | An |
Source code in parax/variables.py
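The field specifiers tagged, derived, and constrained attach variable wrapping to dataclass fields, but their internal mechanism is not shown in this section. The following is only a hypothetical stdlib sketch of one way such specifiers could work: each specifier stores a wrapping function in the field's metadata, and __post_init__ applies it to the supplied value. tagged_field, derived_field, TaggedVar, DerivedVar and Model are invented for illustration, and the real Parax functions may work differently.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class TaggedVar:
    """Hypothetical stand-in for parax.Tagged."""

    raw_value: Any
    metadata: dict = field(default_factory=dict)


@dataclass
class DerivedVar:
    """Hypothetical stand-in for parax.Derived."""

    raw_value: Any
    fn: Callable = lambda x: x

    @property
    def value(self):
        return self.fn(self.raw_value)


def tagged_field(default=dataclasses.MISSING, metadata=None):
    """Hypothetical specifier: with no default, the field is required."""
    wrap = lambda raw: TaggedVar(raw, dict(metadata or {}))
    return field(default=default, metadata={"wrap": wrap})


def derived_field(fn=lambda x: x, default=dataclasses.MISSING):
    """Hypothetical specifier for a derived variable."""
    return field(default=default, metadata={"wrap": lambda raw: DerivedVar(raw, fn)})


@dataclass
class Model:
    weight: Any = tagged_field(metadata={"group": "dense"})  # required
    logit: Any = derived_field(fn=lambda x: 2 * x, default=0.5)

    def __post_init__(self):
        # Wrap plain values for every field that carries a wrapping function;
        # values that are already wrapped are left alone.
        for f in dataclasses.fields(self):
            wrap = f.metadata.get("wrap")
            current = getattr(self, f.name)
            if wrap is not None and not isinstance(current, (TaggedVar, DerivedVar)):
                setattr(self, f.name, wrap(current))


m = Model(weight=1.0)
print(m.weight)       # TaggedVar(raw_value=1.0, metadata={'group': 'dense'})
print(m.logit.value)  # 1.0
```

Calling Model() without weight raises TypeError, matching the documented behaviour that omitting default makes the field required.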