
Transforms

parax.transforms.AbstractTransform

Bases: Module

The base class for all callable transformations in Parax.

Transformations are applied to variables, typically within parax.Derived. Because they inherit from equinox.Module, any JAX arrays stored as attributes (e.g., learnable shifts or scales) can be tracked by optimizers and serialization utilities.
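The attribute-tracking claim rests on JAX's PyTree mechanism, which equinox.Module sets up automatically. The sketch below reproduces the idea with plain JAX so it runs without Parax or Equinox. `LearnableScale` is a hypothetical class, not part of Parax. It is registered by hand with `jax.tree_util.register_pytree_node_class`, so its `scale` attribute becomes a leaf that `jax.grad` differentiates.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for an equinox.Module-based transform. Registering the
# class as a PyTree exposes `scale` as a leaf, which equinox.Module does for
# array attributes automatically.
@jax.tree_util.register_pytree_node_class
class LearnableScale:
    def __init__(self, scale):
        self.scale = jnp.asarray(scale, dtype=float)

    def __call__(self, x):
        return x * self.scale

    def tree_flatten(self):
        # Children (differentiable leaves), then static auxiliary data.
        return (self.scale,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so JAX can rebuild the object from tracers or placeholders.
        obj = object.__new__(cls)
        (obj.scale,) = children
        return obj


def loss(transform, x):
    return jnp.sum(transform(x) ** 2)


x = jnp.array([1.0, 2.0, 3.0])
grads = jax.grad(loss)(LearnableScale(2.0), x)
# grads is itself a LearnableScale: d/d(scale) of scale**2 * sum(x**2) = 2 * 2.0 * 14.0
print(grads.scale)  # 56.0
```

Because the gradient comes back with the same structure as the transform, an optimizer can update it attribute by attribute.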

__call__(x) abstractmethod

Applies the transformation to the input array.

Parameters:

x (Array): The input array to transform. Required.

Returns:

Array: The transformed array.

Source code in parax/transforms.py
@abstractmethod
def __call__(self, x: Array) -> Array:
    """
    Applies the transformation to the input array.

    Args:
        x: The input array to transform.

    Returns:
        The transformed array.
    """
    pass

parax.transforms.Affine(shift=0.0, scale=1.0)

Bases: AbstractTransform

Applies a standard affine transformation: f(x) = x * scale + shift.

Useful for basic standardizations, unit conversions, or parameterizing linear adjustments in a scientific model.
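As a rough illustration of the formula, the sketch below mirrors the arithmetic of Affine.__call__ with plain jax.numpy and uses it for a Celsius-to-Fahrenheit conversion. The `affine` helper is hypothetical, not part of Parax. Because shift and scale are stored as arrays, they broadcast against the input like any other JAX operands.

```python
import jax.numpy as jnp


def affine(x, shift=0.0, scale=1.0):
    # Mirrors Affine: coerce parameters to float arrays, then scale and shift.
    shift = jnp.asarray(shift, dtype=float)
    scale = jnp.asarray(scale, dtype=float)
    return x * scale + shift


# Unit conversion: F = C * 1.8 + 32
celsius = jnp.array([0.0, 100.0, -40.0])
fahrenheit = affine(celsius, shift=32.0, scale=1.8)

# Per-column scaling via broadcasting: a (2,) scale against a (3, 2) input
field = jnp.ones((3, 2))
per_column = affine(field, scale=jnp.array([1.0, 10.0]))
```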

Attributes:

shift (Float[Array, ...]): The translation applied to the input.
scale (Float[Array, ...]): The multiplier applied to the input.

Parameters:

shift (Union[float, Array]): The translation scalar or array. Defaults to 0.0.
scale (Union[float, Array]): The multiplier scalar or array. Defaults to 1.0.

Source code in parax/transforms.py
def __init__(self, shift: Union[float, Array] = 0.0, scale: Union[float, Array] = 1.0):
    """
    Args:
        shift: The translation scalar or array. Defaults to 0.0.
        scale: The multiplier scalar or array. Defaults to 1.0.
    """
    self.shift = jnp.asarray(shift, dtype=float)
    self.scale = jnp.asarray(scale, dtype=float)

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The shifted and scaled array.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The shifted and scaled array.
    """
    return x * self.scale + self.shift

parax.transforms.Shift(shift)

Bases: Affine

Applies a pure translation transformation: f(x) = x + shift.

A convenience subclass of parax.transforms.Affine strictly for shifting values without scaling.

Parameters:

shift (Union[float, Array]): The translation scalar or array. Required.

Source code in parax/transforms.py
def __init__(self, shift: Union[float, Array]):
    """
    Args:
        shift: The translation scalar or array.
    """
    super().__init__(shift=shift, scale=1.0)

parax.transforms.Scale(scale)

Bases: Affine

Applies a pure scaling transformation: f(x) = x * scale.

A convenience subclass of parax.transforms.Affine strictly for scaling values without translation.

Parameters:

scale (Union[float, Array]): The multiplier scalar or array. Required.

Source code in parax/transforms.py
def __init__(self, scale: Union[float, Array]):
    """
    Args:
        scale: The multiplier scalar or array.
    """
    super().__init__(shift=0.0, scale=scale)

parax.transforms.Clip(lower=-jnp.inf, upper=jnp.inf)

Bases: AbstractTransform

Clips (limits) the values in an array to a specified interval.

Corner Case Note: This is a mathematically destructive transformation: the gradient is zero for every input that lies outside [lower, upper] and is therefore clipped. Use it for hard thresholding, not as a replacement for the smooth parax.constraints.
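The gradient behaviour described in the note can be checked directly with jax.grad. This sketch calls jnp.clip as Clip.__call__ does, except that it passes the bounds positionally so it also runs on older JAX versions that predate the min/max keywords. Inputs outside [lower, upper] receive a zero gradient.

```python
import jax
import jax.numpy as jnp

lower, upper = 0.0, 1.0


def clipped_sum(x):
    # Same operation as Clip(lower, upper)(x), summed to get a scalar loss.
    return jnp.sum(jnp.clip(x, lower, upper))


x = jnp.array([-0.5, 0.25, 0.75, 2.0])
grad = jax.grad(clipped_sum)(x)
print(grad)  # [0. 1. 1. 0.]
```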

Attributes:

lower (Float[Array, ...]): The minimum allowable value.
upper (Float[Array, ...]): The maximum allowable value.

Parameters:

lower (Union[float, Array]): The minimum allowable value. Defaults to -inf.
upper (Union[float, Array]): The maximum allowable value. Defaults to inf.

Source code in parax/transforms.py
def __init__(self, lower: Union[float, Array] = -jnp.inf, upper: Union[float, Array] = jnp.inf):
    """
    Args:
        lower: The minimum allowable value. Defaults to -inf.
        upper: The maximum allowable value. Defaults to inf.
    """
    self.lower = jnp.asarray(lower, dtype=float)
    self.upper = jnp.asarray(upper, dtype=float)

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The clipped array.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The clipped array.
    """
    return jnp.clip(x, min=self.lower, max=self.upper)

parax.transforms.Reshape(shape)

Bases: AbstractTransform

Reshapes an array to a specified target shape.

Essential for bridging flat optimizer spaces with multi-dimensional scientific or spatial models (e.g., reshaping a 1D parameter array into a 2D spatial field).
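A minimal sketch of the flat-to-field bridge follows, using jnp.reshape as Reshape.__call__ does. An optimizer works on a flat 12-element vector and the model sees a 3 x 4 field, with the -1 inferred as 4. jax.grad then returns gradients in the original flat layout. The `energy` function is a hypothetical model term.

```python
import jax
import jax.numpy as jnp

flat_params = jnp.arange(12.0)  # what the optimizer sees


def energy(flat):
    # Equivalent to Reshape((3, -1)): -1 is inferred as 12 // 3 = 4
    field = jnp.reshape(flat, (3, -1))
    return jnp.sum(field[:, 0])  # depends only on the first column


grad = jax.grad(energy)(flat_params)
print(grad.shape)  # (12,)
```

Only flat indices 0, 4 and 8 (the first column of the 3 x 4 field) receive a nonzero gradient.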

Attributes:

shape (tuple[int, ...]): The target shape tuple. Can include -1 to infer the size of one dimension automatically.

Parameters:

shape (tuple[int, ...]): The desired target shape. Required.

Source code in parax/transforms.py
def __init__(self, shape: tuple[int, ...]):
    """
    Args:
        shape: The desired target shape.
    """
    self.shape = shape

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The reshaped array.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The reshaped array.
    """
    return jnp.reshape(x, self.shape)

parax.transforms.Round(decimals=0)

Bases: AbstractTransform

Rounds array elements to a given number of decimals.

Useful for quantizing continuous parameters into discrete physical states (e.g., whole integer quantities).

Corner Case Note: This is a mathematically destructive, non-bijective transformation whose gradient is zero almost everywhere. Standard autodiff therefore propagates no learning signal through it unless it is paired with a custom gradient estimator (such as a straight-through estimator).
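One common way to pair rounding with a gradient estimator is the straight-through trick. The forward pass keeps the rounded value, while the backward pass lets the gradient through as if the operation were the identity. The `round_ste` helper below is a hypothetical sketch built on jax.lax.stop_gradient, not a Parax API.

```python
import jax
import jax.numpy as jnp


def round_ste(x, decimals=0):
    # Forward value: jnp.round(x). Backward: stop_gradient hides the rounding
    # residual, so d(round_ste)/dx == 1 everywhere.
    return x + jax.lax.stop_gradient(jnp.round(x, decimals=decimals) - x)


x = jnp.array([0.2, 1.7, 2.5])

plain_grad = jax.grad(lambda v: jnp.sum(jnp.round(v)))(x)  # all zeros
ste_grad = jax.grad(lambda v: jnp.sum(round_ste(v)))(x)    # all ones
```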

Attributes:

decimals (int): The number of decimal places to round to.

Parameters:

decimals (int): Number of decimal places to round to. Defaults to 0 (rounds to the nearest integer).

Source code in parax/transforms.py
def __init__(self, decimals: int = 0):
    """
    Args:
        decimals: Number of decimal places to round to. Defaults to 0 
            (rounds to nearest integer).
    """
    self.decimals = decimals

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The rounded array.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The rounded array.
    """
    return jnp.round(x, decimals=self.decimals)

parax.transforms.Softmax(axis=-1)

Bases: AbstractTransform

Applies the softmax function over a specified axis.

This is a classic non-bijective transformation that maps real-valued inputs onto the probability simplex (non-negative values that sum to 1.0 along the chosen axis). It is useful in machine learning and categorical scientific modeling.
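Both properties can be checked with jax.nn.softmax, the function Softmax.__call__ delegates to. Each row sums to 1. Adding the same constant to every logit in a row leaves the output unchanged, which is why the map cannot be inverted.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])

probs = jax.nn.softmax(logits, axis=-1)
row_sums = probs.sum(axis=-1)  # each row sums to 1

# Shift invariance: many inputs map to the same output (non-bijective)
shifted = jax.nn.softmax(logits + 10.0, axis=-1)
```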

Attributes:

axis (int | tuple[int, ...] | None): The axis or axes along which the softmax should be computed.

Parameters:

axis (Union[int, tuple[int, ...], None]): The axis or axes along which to compute the softmax. Defaults to -1 (the last axis).

Source code in parax/transforms.py
def __init__(self, axis: Union[int, tuple[int, ...], None] = -1):
    """
    Args:
        axis: The axis or axes along which to compute the softmax. 
            Defaults to -1 (the last axis).
    """
    self.axis = axis

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: An array of the same shape where the specified axis forms a probability distribution.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        An array of the same shape where the specified axis forms a 
        probability distribution.
    """
    return jax.nn.softmax(x, axis=self.axis)

parax.transforms.LogSoftmax(axis=-1)

Bases: AbstractTransform

Applies the log-softmax function over a specified axis.

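Mathematically equivalent to log(softmax(x)), but computed with jax.nn.log_softmax, which works with max-shifted inputs (the log-sum-exp trick) and is far more numerically stable than taking the log of a softmax output. Useful when modeling log-probabilities or energy states, where very small probabilities would otherwise underflow.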
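The stability difference is easy to see with extreme logits. In the sketch below, taking the log of a softmax output produces -inf, because exp(-1000) underflows to zero in floating point. jax.nn.log_softmax, the function LogSoftmax.__call__ uses, returns the finite value instead.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([0.0, 1000.0])

naive = jnp.log(jax.nn.softmax(logits))  # first entry underflows to -inf
stable = jax.nn.log_softmax(logits)      # approximately [-1000., 0.]
```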

Attributes:

axis (int | tuple[int, ...] | None): The axis or axes along which the log-softmax should be computed.

Parameters:

axis (Union[int, tuple[int, ...], None]): The axis or axes along which to compute the log-softmax. Defaults to -1 (the last axis).

Source code in parax/transforms.py
def __init__(self, axis: Union[int, tuple[int, ...], None] = -1):
    """
    Args:
        axis: The axis or axes along which to compute the log-softmax. 
            Defaults to -1 (the last axis).
    """
    self.axis = axis

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: An array containing the log-probabilities along the specified axis.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        An array containing the log-probabilities along the specified axis.
    """
    return jax.nn.log_softmax(x, axis=self.axis)

parax.transforms.Normalize(axis=None, epsilon=1e-05)

Bases: AbstractTransform

Normalizes an array to zero mean and approximately unit variance (epsilon is added to the variance, so the result falls slightly below 1).

Corner Case Note: This transform computes the mean and variance dynamically across the provided input x at evaluation time. It does not store running statistics (like a BatchNorm layer).
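The arithmetic is short enough to reproduce directly. The `normalize` helper below is a hypothetical, standalone copy of the formula in Normalize.__call__. It shows that the statistics come from whatever input is passed in. Two rows that differ only by a factor of 10 therefore normalize to almost the same values, and epsilon pulls the variance slightly below 1.

```python
import jax
import jax.numpy as jnp


def normalize(x, axis=None, epsilon=1e-5):
    # Statistics are recomputed from `x` on every call; nothing is stored.
    mean = jnp.mean(x, axis=axis, keepdims=True)
    variance = jnp.var(x, axis=axis, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(variance + epsilon)


x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])
y = normalize(x, axis=-1)  # per-row normalization

row_means = y.mean(axis=-1)  # close to 0
row_vars = y.var(axis=-1)    # slightly below 1 because of epsilon
```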

Attributes:

axis (int | tuple[int, ...] | None): The axis or axes along which to compute the statistics.
epsilon (float): A small scalar added to the variance to prevent division by zero.

Parameters:

axis (Union[int, tuple[int, ...], None]): The axis along which to normalize. If None, normalizes across the entire flattened array. Defaults to None.
epsilon (float): Small scalar to prevent division by zero. Defaults to 1e-5.

Source code in parax/transforms.py
def __init__(self, axis: Union[int, tuple[int, ...], None] = None, epsilon: float = 1e-5):
    """
    Args:
        axis: The axis along which to normalize. If None, normalizes 
            across the entire flattened array. Defaults to None.
        epsilon: Small scalar to prevent division by zero. Defaults to 1e-5.
    """
    self.axis = axis
    self.epsilon = epsilon

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The normalized array, with mean 0 and variance approximately 1 along the specified axis.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The normalized array with mean 0 and variance 1 along the specified axis.
    """
    mean = jnp.mean(x, axis=self.axis, keepdims=True)
    variance = jnp.var(x, axis=self.axis, keepdims=True)
    return (x - mean) * jax.lax.rsqrt(variance + self.epsilon)

parax.transforms.Chain(transforms)

Bases: AbstractTransform

Composes a sequence of transformations into a single transformation.

The transformations are applied in reverse order to match standard mathematical function composition. Mathematically, Chain([f, g, h])(x) is equivalent to f(g(h(x))).
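To make the ordering concrete, the sketch below copies the loop from Chain.__call__ into a standalone helper and composes two simple callables in both orders. `chain` is hypothetical, not a Parax function.

```python
import jax.numpy as jnp


def chain(*fns):
    # Same loop as Chain.__call__: the last function is applied first.
    def composed(x):
        for fn in reversed(fns):
            x = fn(x)
        return x
    return composed


def double(x):
    return x * 2.0


def add_one(x):
    return x + 1.0


x = jnp.array([1.0, 2.0])
a = chain(double, add_one)(x)  # double(add_one(x)) -> [4., 6.]
b = chain(add_one, double)(x)  # add_one(double(x)) -> [3., 5.]
```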

Attributes:

transforms (tuple[AbstractTransform, ...]): A tuple containing the sequence of transformations.

Parameters:

transforms (list[AbstractTransform] | tuple[AbstractTransform, ...]): A sequence (list or tuple) of AbstractTransform instances. They are applied from right to left (last element first). Required.

Source code in parax/transforms.py
def __init__(self, transforms: list[AbstractTransform] | tuple[AbstractTransform, ...]):
    """
    Args:
        transforms: A sequence (list or tuple) of `AbstractTransform` instances.
            They will be applied from right-to-left (last element first).
    """
    # Convert to tuple to guarantee immutability as an Equinox PyTree node
    self.transforms = tuple(transforms)

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The array mapped sequentially through all transformations in reverse order.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The array mapped sequentially through all transformations in reverse order.
    """
    for transform in reversed(self.transforms):
        x = transform(x)
    return x

parax.transforms.BijectorTransform

Bases: AbstractTransform

A transformation powered by a distreqx bijector.

Applies the forward pass of a given distreqx.bijectors.AbstractBijector. This is the standard bridge for injecting complex, mathematically rigorous bijections into derived variables.

Attributes:

bijector (AbstractBijector): The underlying distreqx bijector.

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The array mapped through the bijector's forward pass.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The array mapped through the bijector's forward pass.
    """
    return self.bijector.forward(x)

parax.transforms.CustomTransform(fn)

Bases: AbstractTransform

An escape hatch for power users who need to apply an arbitrary function as a transformation while strictly adhering to the AbstractTransform type hierarchy.

Corner Case Note: The underlying callable is marked as static=True. This prevents JAX from attempting to flatten raw Python functions (like lambdas). If your custom transformation requires learnable arrays or state, you should subclass AbstractTransform directly instead of using this wrapper.
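The consequence of storing the callable as static data can be sketched with plain JAX. `StaticFnWrapper` below is a hypothetical stand-in, not Parax's implementation. It keeps its function in the PyTree's auxiliary data, which is roughly what a static=True field does. The array captured by the lambda never appears among the PyTree leaves, so jax.grad and optimizers cannot see or update it.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class StaticFnWrapper:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, x):
        return self.fn(x)

    def tree_flatten(self):
        # No children: the callable is auxiliary (static) data, not a leaf.
        return (), self.fn

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(aux_data)


scale = jnp.array(3.0)
wrapped = StaticFnWrapper(lambda x: x * scale)

leaves = jax.tree_util.tree_leaves(wrapped)  # [] : `scale` is invisible
out = wrapped(jnp.array([1.0, 2.0]))         # still computes x * 3.0
```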

Attributes:

_custom_fn (Callable): The internal, user-defined callable.

Parameters:

fn (Callable): The custom callable. Must accept a single array argument and return a transformed array. Required.

Source code in parax/transforms.py
def __init__(
    self, 
    fn: Callable
):
    """
    Args:
        fn: The custom callable. Must accept a single array argument 
            and return a transformed array.
    """
    self._custom_fn = fn

__call__(x)

Parameters:

x (Array): The input array. Required.

Returns:

Array: The array mapped through the custom function.

Source code in parax/transforms.py
def __call__(self, x: Array) -> Array:
    """
    Args:
        x: The input array.

    Returns:
        The array mapped through the custom function.
    """
    return self._custom_fn(x)

parax.transforms.TreeTransform(transforms)

Bases: AbstractTransform

Applies a PyTree of transformations to a matching PyTree of inputs.

Useful for applying heterogeneous transformations to complex nested structures (like equinox.Module instances) simultaneously.
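The mapping logic can be sketched with jax.tree_util.tree_map alone. In this hypothetical example, plain callables stand in for AbstractTransform leaves and `callable` stands in for parax.filters.is_transform. The string "keep" is a non-transform leaf, so its input passes through unchanged. Because tree_map treats the transform tree as a prefix of the input tree, a transform leaf would receive the whole matching subtree if the input were nested more deeply.

```python
import jax
import jax.numpy as jnp


def is_transform(node):
    # Hypothetical stand-in for parax.filters.is_transform
    return callable(node)


def apply_tree(transforms, x):
    # Same pattern as TreeTransform.__call__
    def _apply(transform, val):
        return transform(val) if is_transform(transform) else val

    return jax.tree_util.tree_map(_apply, transforms, x, is_leaf=is_transform)


transforms = {"rate": jnp.exp, "weights": jax.nn.softmax, "raw": "keep"}
params = {
    "rate": jnp.array(0.0),
    "weights": jnp.array([0.0, 0.0]),
    "raw": jnp.array([1.0, -1.0]),
}
out = apply_tree(transforms, params)
# out["rate"] == 1.0, out["weights"] == [0.5, 0.5], out["raw"] unchanged
```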

Attributes:

tree (PyTree[AbstractTransform]): The PyTree containing AbstractTransform leaves.

Parameters:

transforms (PyTree[AbstractTransform]): A PyTree containing AbstractTransform leaves. Non-transform leaves leave the corresponding input unchanged. Required.

Raises:

ValueError: If the provided PyTree is empty (it has no leaves at all).

Source code in parax/transforms.py
def __init__(
    self, 
    transforms: PyTree[AbstractTransform],
):
    """
    Args:
        transforms: A PyTree containing `AbstractTransform` leaves.
            Non-transform leaves are ignored.

    Raises:
        ValueError: If the provided PyTree contains no transform leaves.
    """
    # Local import prevents circular dependency at initialization time
    from parax.filters import is_transform

    leaves = jax.tree.leaves(transforms, is_leaf=is_transform)
    if not leaves:
        raise ValueError("The pytree of transforms cannot be empty.")

    self.tree = transforms

__call__(x)

Maps each leaf transform over the corresponding node in the input PyTree.

Parameters:

x (PyTree[Array]): The input PyTree of arrays. The transforms PyTree must be a tree prefix of x; each transform leaf is applied to the corresponding node (an array or a whole subtree) of x. Required.

Returns:

PyTree[Array]: A new PyTree containing the transformed arrays.

Source code in parax/transforms.py
def __call__(self, x: PyTree[Array]) -> PyTree[Array]:
    """
    Maps each leaf transform over the corresponding node in the input PyTree.

    Args:
        x: The input PyTree of arrays. Must have a matching tree prefix 
            to the `transforms` PyTree.

    Returns:
        A new PyTree containing the transformed arrays.
    """
    from parax.filters import is_transform

    def _apply_transform(transform: Any, val: Any) -> Any:
        if not is_transform(transform):
            return val
        return transform(val)

    return jax.tree_util.tree_map(_apply_transform, self.tree, x, is_leaf=is_transform)