Transforms
parax.transforms.AbstractTransform
Bases: Module
The base class for all callable transformations in Parax.
Transformations are applied to variables, typically within parax.Derived.
Because they inherit from equinox.Module, any JAX arrays stored as
attributes (e.g., learnable shifts or scales) can be tracked
by optimizers and serialization utilities.
__call__(x) (abstractmethod)
Applies the transformation to the input array.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array to transform. | required |
Returns:
| Type | Description |
|---|---|
| Array | The transformed array. |
Source code in parax/transforms.py
parax.transforms.Affine(shift=0.0, scale=1.0)
Bases: AbstractTransform
Applies a standard affine transformation: f(x) = x * scale + shift.
Useful for basic standardizations, unit conversions, or parameterizing linear adjustments in a scientific model.
Attributes:
| Name | Type | Description |
|---|---|---|
| shift | Float[Array, ...] | The translation applied to the input. |
| scale | Float[Array, ...] | The multiplier applied to the input. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shift | Union[float, Array] | The translation scalar or array. Defaults to 0.0. | 0.0 |
| scale | Union[float, Array] | The multiplier scalar or array. Defaults to 1.0. | 1.0 |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The shifted and scaled array. |
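Because Affine is plain arithmetic, array-valued shift and scale presumably combine with x under ordinary broadcasting rules (an assumption; the reference above only states that both may be arrays). The standalone jax.numpy sketch below reproduces the documented formula f(x) = x * scale + shift and its inverse; it does not import parax, and the helper names affine and affine_inverse are hypothetical. With parax installed, parax.transforms.Affine(shift=shift, scale=scale)(x) should give the same values.

```python
import jax.numpy as jnp


def affine(x, shift=0.0, scale=1.0):
    """Hypothetical standalone sketch of f(x) = x * scale + shift (not parax's code)."""
    return x * scale + shift


def affine_inverse(y, shift=0.0, scale=1.0):
    """Undo the affine map; only valid when every scale entry is nonzero."""
    return (y - shift) / scale


x = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
# Array-valued shift and scale broadcast against the last axis (one value per column).
shift = jnp.array([10.0, 20.0])
scale = jnp.array([2.0, 0.5])

y = affine(x, shift=shift, scale=scale)
print(y)  # values: [[12, 21], [16, 22]]

x_back = affine_inverse(y, shift=shift, scale=scale)
print(jnp.allclose(x_back, x))  # True
```

The inverse is only shown to make the point that an affine map with nonzero scale loses no information, unlike Clip or Round further down.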
parax.transforms.Shift(shift)
Bases: Affine
Applies a pure translation transformation: f(x) = x + shift.
A convenience subclass of parax.transforms.Affine strictly for
shifting values without scaling.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shift | Union[float, Array] | The translation scalar or array. | required |
parax.transforms.Scale(scale)
Bases: Affine
Applies a pure scaling transformation: f(x) = x * scale.
A convenience subclass of parax.transforms.Affine strictly for
scaling values without translation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| scale | Union[float, Array] | The multiplier scalar or array. | required |
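Shift and Scale are Affine with one argument pinned to its identity value (scale=1.0 or shift=0.0). The sketch below illustrates this with two unit conversions, written in plain jax.numpy; the helpers shift_only and scale_only are hypothetical stand-ins for Shift(shift) and Scale(scale), not parax code.

```python
import jax.numpy as jnp


def affine(x, shift=0.0, scale=1.0):
    return x * scale + shift


def shift_only(x, shift):
    # Mirrors Shift(shift): an affine map with scale fixed at 1.0.
    return affine(x, shift=shift, scale=1.0)


def scale_only(x, scale):
    # Mirrors Scale(scale): an affine map with shift fixed at 0.0.
    return affine(x, shift=0.0, scale=scale)


celsius = jnp.array([0.0, 100.0])
kelvin = shift_only(celsius, 273.15)     # approximately [273.15, 373.15]

kilometres = jnp.array([1.5, 2.0])
metres = scale_only(kilometres, 1000.0)  # [1500., 2000.]
```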
parax.transforms.Clip(lower=-jnp.inf, upper=jnp.inf)
Bases: AbstractTransform
Clips (limits) the values in an array to a specified interval.
Corner Case Note: This is a destructive, non-invertible transformation:
the gradient is zero for every input that lies outside [lower, upper], so
clipped values receive no learning signal. Use it for hard thresholding,
not as a replacement for the smooth parax.constraints.
Attributes:
| Name | Type | Description |
|---|---|---|
| lower | Float[Array, ...] | The minimum allowable value. |
| upper | Float[Array, ...] | The maximum allowable value. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lower | Union[float, Array] | The minimum allowable value. Defaults to -inf. | -inf |
| upper | Union[float, Array] | The maximum allowable value. Defaults to inf. | inf |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The clipped array. |
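The gradient behavior described in the corner-case note can be checked directly with jax.grad. This sketch uses jnp.clip as a stand-in for Clip(lower=0.0, upper=1.0); it illustrates the JAX semantics rather than parax's own implementation.

```python
import jax
import jax.numpy as jnp

xs = jnp.array([-1.0, 0.5, 2.0])

# Forward pass: values outside [0, 1] are pinned to the nearest bound.
clipped = jnp.clip(xs, 0.0, 1.0)
print(clipped)  # values: [0.0, 0.5, 1.0]

# Per-element derivative of the clipped output with respect to its input.
grad_clip = jax.vmap(jax.grad(lambda x: jnp.clip(x, 0.0, 1.0)))
grads = grad_clip(xs)
print(grads)  # values: [0.0, 1.0, 0.0], no signal for the clipped entries
```

Only the in-range element passes a gradient through, which is why a smooth constraint is preferable when the bounded quantity is being optimized.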
parax.transforms.Reshape(shape)
Bases: AbstractTransform
Reshapes an array to a specified target shape.
Essential for bridging flat optimizer spaces with multi-dimensional scientific or spatial models (e.g., reshaping a 1D parameter array into a 2D spatial field).
Attributes:
| Name | Type | Description |
|---|---|---|
| shape | tuple[int, ...] | The target shape tuple. Can include |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| shape | tuple[int, ...] | The desired target shape. | required |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The reshaped array. |
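A minimal sketch of the flat-to-spatial bridge described above, using jnp.reshape directly in place of Reshape(shape=(3, 4)). The grid size GRID_SHAPE and the toy loss are hypothetical. The point it shows is that gradients flow back through the reshape in the flat layout the optimizer works with.

```python
import jax
import jax.numpy as jnp

GRID_SHAPE = (3, 4)  # hypothetical spatial grid: 3 rows x 4 columns

flat_params = jnp.arange(12.0)                # flat vector, as an optimizer stores it
field = jnp.reshape(flat_params, GRID_SHAPE)  # 2D field, as the model consumes it
print(field.shape)  # (3, 4)
print(field[1, 0])  # 4.0 (row-major order: the fifth flat entry)


def loss(p):
    # Toy objective defined on the 2D field: half the sum of squares.
    return 0.5 * jnp.sum(jnp.reshape(p, GRID_SHAPE) ** 2)


# The gradient of 0.5 * sum(p**2) is p itself, returned with the flat shape.
g = jax.grad(loss)(flat_params)
print(g.shape)  # (12,)
```

The total number of elements must match: reshaping 12 values into a (3, 5) grid raises an error.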
parax.transforms.Round(decimals=0)
Bases: AbstractTransform
Rounds array elements to a given number of decimals.
Useful for quantizing continuous parameters into discrete physical states (e.g., whole integer quantities).
Corner Case Note: This is a destructive, non-bijective transformation whose gradient is zero almost everywhere. It blocks gradient flow under standard autodiff unless paired with a custom gradient estimator (such as a straight-through estimator).
Attributes:
| Name | Type | Description |
|---|---|---|
| decimals | int | The number of decimal places to round to. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| decimals | int | Number of decimal places to round to. Defaults to 0 (rounds to nearest integer). | 0 |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The rounded array. |
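The sketch below shows both halves of the corner-case note: plain jnp.round has a zero gradient, and a straight-through estimator restores an identity gradient while keeping rounded values in the forward pass. The helper round_ste is hypothetical and is not part of parax; it uses the common x + stop_gradient(round(x) - x) construction.

```python
import jax
import jax.numpy as jnp


def round_ste(x, decimals=0):
    # Forward pass: rounded values. Backward pass: identity gradient, because
    # the (rounded - x) correction is hidden from autodiff by stop_gradient.
    return x + jax.lax.stop_gradient(jnp.round(x, decimals) - x)


x = jnp.array([0.4, 1.6, 2.5])
print(jnp.round(x))  # values: [0.0, 2.0, 2.0] (ties round to the nearest even value)

plain_grad = jax.grad(lambda v: jnp.sum(jnp.round(v)))(x)
ste_grad = jax.grad(lambda v: jnp.sum(round_ste(v)))(x)
print(plain_grad)  # all zeros: autodiff sees no dependence on x
print(ste_grad)    # all ones: gradients pass straight through
```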
parax.transforms.Softmax(axis=-1)
Bases: AbstractTransform
Applies the softmax function over a specified axis.
This is a classic non-bijective transformation: it maps real-valued inputs onto the probability simplex (non-negative values that sum to 1.0 along the axis). Useful in ML and in categorical scientific modeling.
Attributes:
| Name | Type | Description |
|---|---|---|
| axis | Union[int, tuple[int, ...], None] | The axis or axes along which the softmax should be computed. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| axis | Union[int, tuple[int, ...], None] | The axis or axes along which to compute the softmax. Defaults to -1 (the last axis). | -1 |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | An array of the same shape where the specified axis forms a probability distribution. |
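A short check of the two properties claimed above, using jax.nn.softmax as a stand-in for Softmax(axis=-1): each row sums to 1, and the map is not invertible because adding a constant to a row does not change its output.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])

probs = jax.nn.softmax(logits, axis=-1)
print(probs.sum(axis=-1))  # each row sums to 1 (up to float rounding)
print(probs[1])            # uniform row: each entry is about 1/3

# Non-bijective: shifting every logit in a row by the same constant leaves the
# probabilities unchanged, so the original logits cannot be recovered.
shifted_probs = jax.nn.softmax(logits + 5.0, axis=-1)
print(jnp.allclose(probs, shifted_probs))  # True
```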
parax.transforms.LogSoftmax(axis=-1)
Bases: AbstractTransform
Applies the log-softmax function over a specified axis.
Mathematically equivalent to log(softmax(x)), but evaluated in a
numerically stable form rather than by literally taking the log of a
softmax output. Useful when modeling log-probabilities or energy states,
where very small probabilities would otherwise underflow to zero.
Attributes:
| Name | Type | Description |
|---|---|---|
| axis | Union[int, tuple[int, ...], None] | The axis or axes along which the log-softmax should be computed. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| axis | Union[int, tuple[int, ...], None] | The axis or axes along which to compute the log-softmax. Defaults to -1 (the last axis). | -1 |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | An array containing the log-probabilities along the specified axis. |
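The underflow problem is easy to reproduce. In the sketch below (plain JAX, standing in for LogSoftmax), exp(-1000) underflows to 0 in float32, so the naive log(softmax(x)) returns -inf for the second entry, while jax.nn.log_softmax returns the exact log-probability.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.0, -1000.0])

# Naive route: the tiny probability underflows to 0, and log(0) is -inf.
naive = jnp.log(jax.nn.softmax(x))

# Stable route: mathematically x - logsumexp(x), computed without forming exp(-1000).
stable = jax.nn.log_softmax(x)

print(naive)   # first entry 0, second entry -inf
print(stable)  # first entry 0, second entry -1000
```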
parax.transforms.Normalize(axis=None, epsilon=1e-05)
Bases: AbstractTransform
Normalizes an array to zero mean and (approximately) unit variance.
Corner Case Note: This transform computes the mean and variance
dynamically across the provided input x at evaluation time. It does
not store running statistics (unlike a BatchNorm layer).
Attributes:
| Name | Type | Description |
|---|---|---|
| axis | Union[int, tuple[int, ...], None] | The axis or axes along which to compute the statistics. |
| epsilon | float | A small scalar added to the variance to prevent division by zero. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| axis | Union[int, tuple[int, ...], None] | The axis or axes along which to normalize. If None, normalizes across the entire flattened array. Defaults to None. | None |
| epsilon | float | Small scalar to prevent division by zero. Defaults to 1e-5. | 1e-05 |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The normalized array, with mean 0 and variance approximately 1 (slightly below 1 because of epsilon) along the specified axis. |
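A minimal sketch of the normalization, assuming the formula (x - mean) / sqrt(var + epsilon) implied by the epsilon description and assuming population variance (ddof=0); parax's exact implementation may differ. The hypothetical helper normalize shows why epsilon matters: a constant row has zero variance, and epsilon keeps the division finite.

```python
import jax.numpy as jnp


def normalize(x, axis=None, epsilon=1e-5):
    # Statistics are computed from x itself on every call; nothing is stored.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    var = jnp.var(x, axis=axis, keepdims=True)  # population variance (ddof=0), an assumption
    return (x - mean) / jnp.sqrt(var + epsilon)


x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 10.0, 10.0, 10.0]])

per_row = normalize(x, axis=-1)
# Row 0: mean ~0 and standard deviation just under 1 (epsilon shrinks it slightly).
# Row 1 is constant: its variance is 0, epsilon keeps the division finite,
# and the whole row becomes 0.
print(per_row)

whole = normalize(x)  # axis=None: one mean and variance for the entire array
```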
parax.transforms.Chain(transforms)
Bases: AbstractTransform
Composes a sequence of transformations into a single transformation.
The transformations are applied in reverse order (the last element of the
sequence first), matching standard mathematical function composition:
Chain([f, g, h])(x) is equivalent to f(g(h(x))).
Attributes:
| Name | Type | Description |
|---|---|---|
| transforms | tuple[AbstractTransform, ...] | A tuple containing the sequence of transformations. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| transforms | Union[list[AbstractTransform], tuple[AbstractTransform, ...]] | A sequence (list or tuple) of AbstractTransform instances. | required |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The array after passing through every transformation, starting with the last one in the sequence. |
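The ordering rule is the easiest part of Chain to get backwards, so here is a pure-Python sketch of it. The hypothetical chain helper composes plain callables the same way the text describes (last element applied first); it is not parax's implementation.

```python
from functools import reduce


def chain(transforms):
    """Compose callables so that chain([f, g, h])(x) == f(g(h(x)))."""
    transforms = tuple(transforms)

    def composed(x):
        # Walk the sequence from the last element to the first.
        return reduce(lambda acc, t: t(acc), reversed(transforms), x)

    return composed


def add_one(x):
    return x + 1


def double(x):
    return 2 * x


print(chain([add_one, double])(3))  # add_one(double(3)) == 7
print(chain([double, add_one])(3))  # double(add_one(3)) == 8
```

By the same rule, Chain([Scale(2.0), Clip(lower=0.0, upper=1.0)]) would clip first and then scale, so its outputs lie in [0, 2].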
parax.transforms.BijectorTransform
Bases: AbstractTransform
A transformation powered by a distreqx bijector.
Applies the forward pass of a given distreqx.bijectors.AbstractBijector.
This is the standard bridge for injecting complex bijections from
distreqx into derived variables.
Attributes:
| Name | Type | Description |
|---|---|---|
| bijector | AbstractBijector | The underlying distreqx bijector. |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The array mapped through the bijector's forward pass. |
parax.transforms.CustomTransform(fn)
Bases: AbstractTransform
An escape hatch for power users who need to apply an arbitrary function
as a transformation while strictly adhering to the AbstractTransform type hierarchy.
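Corner Case Note: The underlying callable is stored as a static field (static=True).
This keeps raw Python functions (like lambdas) out of the PyTree leaves, so JAX never tries to treat them as arrays.
As a consequence, any arrays captured inside the function are invisible to optimizers and serialization utilities.
If your custom transformation requires learnable arrays or state, subclass AbstractTransform directly instead of using this wrapper.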
Attributes:
| Name | Type | Description |
|---|---|---|
| _custom_fn | Callable | The internal, user-defined callable. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable | The custom callable. Must accept a single array argument and return a transformed array. | required |
__call__(x)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | The input array. | required |
Returns:
| Type | Description |
|---|---|
| Array | The array mapped through the custom function. |
parax.transforms.TreeTransform(transforms)
Bases: AbstractTransform
Represents a PyTree of transformations mapping over a PyTree of inputs.
Useful for applying heterogeneous transformations to complex nested structures
(like equinox.Module instances) simultaneously.
Attributes:
| Name | Type | Description |
|---|---|---|
| tree | PyTree[AbstractTransform] | The PyTree containing AbstractTransform leaves. |
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| transforms | PyTree[AbstractTransform] | A PyTree containing AbstractTransform leaves. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the provided PyTree contains no transform leaves. |
__call__(x)
Maps each leaf transform over the corresponding node in the input PyTree.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | PyTree[Array] | The input PyTree of arrays. The transform PyTree must be a tree prefix of it. | required |
Returns:
| Type | Description |
|---|---|
| PyTree[Array] | A new PyTree containing the transformed arrays. |
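The prefix-matching behavior can be sketched with jax.tree_util alone. The hypothetical tree_transform helper below assumes one reading of "maps each leaf transform over the corresponding node": each transform leaf is applied to every array inside the matching subtree of x. Whether parax passes the whole subtree to the transform or maps it leaf by leaf is not stated above, so treat this as an illustration rather than a specification. Plain lambdas stand in for AbstractTransform instances.

```python
import jax
import jax.numpy as jnp


def tree_transform(transforms, x):
    """Hypothetical sketch: apply each transform leaf to every array in the matching subtree of x."""
    if not jax.tree_util.tree_leaves(transforms):
        raise ValueError("The transform PyTree contains no transform leaves.")
    # `transforms` is flattened first, so it only needs to be a tree prefix of `x`:
    # each transform leaf is paired with a whole subtree of `x`, which is then
    # mapped leaf by leaf.
    return jax.tree_util.tree_map(
        lambda f, subtree: jax.tree_util.tree_map(f, subtree), transforms, x
    )


params = {
    "weights": (jnp.array([-1.0, 2.0]), jnp.array([3.0])),  # a nested subtree
    "temperature": jnp.array(0.5),                           # stored in log space
}
transforms = {
    "weights": lambda w: jnp.clip(w, 0.0, 1.0),  # one transform for the whole subtree
    "temperature": lambda t: jnp.exp(t),         # map to a positive value
}

out = tree_transform(transforms, params)
print(out["weights"])      # clipped arrays: [0., 1.] and [1.]
print(out["temperature"])  # exp(0.5), about 1.6487
```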