
API Reference

parax.Parameter(value=None, fixed=False, metadata=None, n=None, **kwargs)

Bases: Module

A container for a parameter.

This class serves as the fundamental building block for defining parameters with metadata within Equinox modules. It is designed to be a flexible container that behaves like a standard JAX array (i.e., a jax.numpy.ndarray) while holding additional metadata for model training and analysis.

Usage
  • Use in mathematical operations just like a JAX/numpy array.
  • Parameter objects are JAX PyTrees, compatible with JAX transformations (jit, grad).
  • Mark as fixed (honored by parax.partition).
  • Associate distributions and transforms/bijectors using distreqx.

During initialization, core metadata and arbitrary kwargs are automatically routed into the hidden ParameterMetadata struct. If a transform/bijector is provided, the input value is assumed to be in the physical (constrained) space and is automatically inverted to store the latent (unconstrained) value.

The parameter n allows for vectorizing the input value and metadata across n dimensions.
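
The physical-to-latent round trip described above can be sketched without parax or distreqx. `ExpBijector` and `TinyParameter` are hypothetical stand-ins invented for this illustration, not the library's classes; the sketch only assumes that a transform exposes `forward` and `inverse`, as the constructor source suggests.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ExpBijector:
    """Hypothetical stand-in for a bijector: latent (any real) -> physical (> 0)."""

    def forward(self, x: float) -> float:
        return math.exp(x)

    def inverse(self, y: float) -> float:
        return math.log(y)


@dataclass(frozen=True)
class TinyParameter:
    """Hypothetical stand-in that stores only the latent value."""

    latent_value: float
    transform: Optional[ExpBijector] = None

    @classmethod
    def from_physical(cls, value: float, transform: Optional[ExpBijector] = None):
        # Mirror the constructor: invert the physical value when a transform is given.
        latent = transform.inverse(value) if transform is not None else value
        return cls(latent_value=latent, transform=transform)

    @property
    def value(self) -> float:
        # Map the stored latent value back to physical space on access.
        if self.transform is None:
            return self.latent_value
        return self.transform.forward(self.latent_value)


p = TinyParameter.from_physical(2.0, transform=ExpBijector())
print(p.latent_value)  # log(2), the unconstrained value that is stored
print(p.value)         # recovers 2.0 up to floating-point rounding
```

An optimizer would update `latent_value` freely, while `value` always stays inside the constrained range. In the real constructor, a physical value outside the transform's range that inverts to NaN raises a `ValueError`.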

Source code in parax/parameter.py
def __init__(
    self, 
    value: Any | None = None, 
    fixed: bool = False, 
    metadata: ParameterMetadata | None = None, 
    n: int | None = None,
    **kwargs
):
    """
    During initialization, core metadata and arbitrary kwargs are automatically routed into the 
    hidden `ParameterMetadata` struct. If a transform/bijector is provided, the input 
    `value` is assumed to be in the physical (constrained) space and is 
    automatically inverted to store the latent (unconstrained) value.

    The parameter `n` allows for vectorizing the input value and metadata 
    across `n` dimensions.
    """
    latent_value = kwargs.pop('latent_value', None)

    if latent_value is None and value is None:
        raise Exception("Must pass one of either `latent_value` or `value` to Parameter constructor")

    # 1. Handle Vectorization (n)
    if n is not None:
        if value is not None:
            value = jnp.broadcast_to(jnp.asarray(value), (n,) + jnp.shape(value))
        if latent_value is not None:
            latent_value = jnp.broadcast_to(jnp.asarray(latent_value), (n,) + jnp.shape(latent_value))

    # 2. Extract known metadata keys
    updates = {}
    for key in ["name", "distribution", "transform", "bounds", "scale"]:
        if key in kwargs:
            updates[key] = kwargs.pop(key)

    # Handle vectorization for specific metadata fields
    if n is not None and n != 1:
        # Vectorize bounds: if shape is (2,), it becomes (n, 2)
        if "bounds" in updates and updates["bounds"] is not None:
            b = jnp.asarray(updates["bounds"])
            updates["bounds"] = jnp.broadcast_to(b, (n,) + b.shape)

        # Vectorize name: if a string is passed, turn into a list of n strings
        if "name" in updates and isinstance(updates["name"], str):
            updates["name"] = [f"{updates['name']}_{i}" for i in range(n)]

    # Format specific fields
    if "name" in updates and isinstance(updates["name"], tuple):
        updates["name"] = list(updates["name"])

    if "bounds" in updates and updates["bounds"] is not None:
        updates["bounds"] = jnp.asarray(updates["bounds"])

    # Any remaining kwargs belong in the custom 'info' dict
    info_updates = kwargs if len(kwargs) > 0 else {}

    # 3. Reconcile metadata
    if metadata is not None:
        if updates or info_updates:
            new_info = dict(metadata.info) if metadata.info is not None else {}
            new_info.update(info_updates)

            self.metadata = dataclasses.replace(
                metadata, 
                **updates, 
                info=new_info if new_info else None
            )
        else:
            self.metadata = metadata
    else:
        name = updates.get("name", None)
        distribution = updates.get("distribution", None)
        transform = updates.get("transform", None)
        bounds = updates.get("bounds", None)
        scale = updates.get("scale", 1.0)

        if (distribution is None and transform is None and bounds is None and 
            scale == 1.0 and name is None and not info_updates):
            self.metadata = None
        else:
            self.metadata = ParameterMetadata(
                name=name,
                distribution=distribution,
                transform=transform,
                bounds=bounds,
                scale=scale,
                info=info_updates if info_updates else None
            )

    # 4. Handle latent value extraction/inversion
    if latent_value is None:
        latent_value = jnp.asarray(value)
        if self.metadata is not None and self.metadata.transform is not None:
            # The bijector in distreqx handles vectorized inputs automatically
            latent_value = self.metadata.transform.inverse(latent_value)

            if jnp.any(jnp.isnan(latent_value)):
                raise ValueError("Got nan while applying bijector inverse in parameter init.")

    self.latent_value = latent_value
    self.fixed = fixed
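
The metadata half of the `n` vectorization step can be illustrated in plain Python. `vectorize_metadata` is a hypothetical helper written for this sketch (it uses lists instead of JAX arrays); it follows the branch in the listing above that expands a string name and a bounds pair when `n` is set and not equal to 1.

```python
def vectorize_metadata(n, name=None, bounds=None):
    """Sketch of how `n` expands a string name and a (low, high) bounds pair."""
    if n is None or n == 1:
        # The listing leaves name and bounds untouched in this case.
        return name, bounds
    if isinstance(name, str):
        # A single string becomes n indexed names: "w" -> "w_0", "w_1", ...
        name = [f"{name}_{i}" for i in range(n)]
    if bounds is not None:
        # One (low, high) pair is repeated n times, giving an (n, 2) layout.
        bounds = [tuple(bounds) for _ in range(n)]
    return name, bounds


names, bounds = vectorize_metadata(3, name="w", bounds=(0.0, 1.0))
print(names)   # ['w_0', 'w_1', 'w_2']
print(bounds)  # [(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)]
```

Note that, per the listing, the value itself is broadcast for any non-None `n`, including `n=1`, whereas the name and bounds are only expanded when `n != 1`.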

bounds property

Get the parameter bounds.

Returns:

Type Description
ndarray or None

The physical bounds of the parameter.

distribution property

Get the parameter distribution.

Returns:

Type Description
AbstractDistribution or None

The probability distribution associated with the parameter.

info property

Get the parameter's custom metadata.

Returns:

Type Description
dict

Any arbitrary keyword arguments passed during initialization.

latent_distribution property

Get the parameter distribution in the latent space.

Returns:

Type Description
AbstractDistribution or None

The physical probability distribution mapped back to the latent space via the inverse of the parameter's transform.

name property

Get the parameter name.

Returns:

Type Description
str, list of str, or None

The name or list of names associated with the parameter.

scale property

Get the parameter scale.

Returns:

Type Description
float

The multiplier applied to the physical value for numerical operations.

shape property

Get the shape of the parameter.

Returns:

Type Description
tuple of int

The shape of the latent array.

size property

Get the number of elements in the parameter.

Returns:

Type Description
int

The total size of the latent array.

transform property

Get the parameter transform.

Returns:

Type Description
AbstractBijector or None

The bijector used to map between latent and physical space.

value property

Get the unscaled physical space value.

Returns:

Type Description
ndarray

The parameter value mapped through the transform (if any), but unscaled.

__add__(other)

Elementwise addition.

Source code in parax/parameter.py
def __add__(self, other):
    """Elementwise addition."""
    return jnp.add(jnp.array(self), jnp.array(other))

__array__(dtype=None)

NumPy array interface.

Returns:

Type Description
ndarray

The fully scaled and physical space array.

Source code in parax/parameter.py
def __array__(self, dtype=None):
    """
    NumPy array interface.

    Returns
    -------
    numpy.ndarray
        The fully scaled and physical space array.
    """
    return jnp.asarray(self.value * self.scale, dtype=dtype)

__jax_array__(dtype=None)

JAX array interface.

Returns:

Type Description
ndarray

The fully scaled and physical space array.

Source code in parax/parameter.py
def __jax_array__(self, dtype=None):
    """
    JAX array interface.

    Returns
    -------
    jnp.ndarray
        The fully scaled and physical space array.
    """
    return jnp.asarray(self.value * self.scale, dtype=dtype)
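
Both array hooks return `value * scale`, so there are three distinct quantities: the latent value, the unscaled physical value, and the scaled array used in arithmetic. The sketch below separates them using plain floats; `physical_value` and `array_value` are hypothetical helper names, and `math.exp` stands in for a transform's forward map.

```python
import math


def physical_value(latent, forward=None):
    """Unscaled physical value: the latent value mapped through the transform."""
    return forward(latent) if forward is not None else latent


def array_value(latent, scale=1.0, forward=None):
    """What the array hooks hand to numerical code: physical value times scale."""
    return physical_value(latent, forward) * scale


latent = math.log(3.0)
value = physical_value(latent, forward=math.exp)              # about 3.0
scaled = array_value(latent, scale=1e-3, forward=math.exp)    # about 0.003
print(latent, value, scaled)
```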

__len__()

Get the length of the parameter value.

Returns:

Type Description
int

1 for scalars, otherwise len(latent_value).

Source code in parax/parameter.py
def __len__(self):
    """
    Get the length of the parameter value.

    Returns
    -------
    int
        `1` for scalars, otherwise `len(latent_value)`.
    """
    if len(self.latent_value.shape) == 0:
        return 1 
    return len(self.latent_value)

__mul__(other)

Elementwise multiplication.

Source code in parax/parameter.py
def __mul__(self, other):
    """Elementwise multiplication."""
    return jnp.multiply(jnp.array(self), jnp.array(other))

__radd__(other)

Reflected elementwise addition.

Source code in parax/parameter.py
def __radd__(self, other):
    """Reflected elementwise addition."""
    return jnp.add(jnp.array(other), jnp.array(self))

__rmul__(other)

Reflected elementwise multiplication.

Source code in parax/parameter.py
def __rmul__(self, other):
    """Reflected elementwise multiplication."""
    return jnp.multiply(jnp.array(other), jnp.array(self))

__rsub__(other)

Reflected elementwise subtraction.

Source code in parax/parameter.py
def __rsub__(self, other):
    """Reflected elementwise subtraction."""
    return jnp.subtract(jnp.array(other), jnp.array(self))

__rtruediv__(other)

Reflected elementwise true division.

Source code in parax/parameter.py
def __rtruediv__(self, other):
    """Reflected elementwise true division."""
    return jnp.divide(jnp.array(other), jnp.array(self))

__sub__(other)

Elementwise subtraction.

Source code in parax/parameter.py
def __sub__(self, other):
    """Elementwise subtraction."""
    return jnp.subtract(jnp.array(self), jnp.array(other))

__truediv__(other)

Elementwise true division.

Source code in parax/parameter.py
def __truediv__(self, other):
    """Elementwise true division."""
    return jnp.divide(jnp.array(self), jnp.array(other))
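
The operator methods above all follow one pattern: convert both operands to arrays, then apply the operation, with the reflected variants swapping the operand order. A minimal plain-Python sketch of that pattern follows; `Scaled` is a hypothetical stand-in that converts to `float` where the real class converts to a JAX array.

```python
class Scaled:
    """Hypothetical stand-in: converts to a float before every operation."""

    def __init__(self, value, scale=1.0):
        self.value = value
        self.scale = scale

    def __float__(self):
        # Plays the role of the array hooks: physical value times scale.
        return float(self.value * self.scale)

    def __sub__(self, other):
        return float(self) - float(other)

    def __rsub__(self, other):
        # Reflected form: Python calls this for `other - self`
        # when `other` does not know how to handle a Scaled operand.
        return float(other) - float(self)

    def __truediv__(self, other):
        return float(self) / float(other)

    def __rtruediv__(self, other):
        return float(other) / float(self)


p = Scaled(4.0, scale=0.5)  # behaves like 2.0 in arithmetic
print(p - 1, 1 - p)         # 1.0 -1.0
print(p / 4, 4 / p)         # 0.5 2.0
```

As in the listings above, the result of an operation is a plain numeric value rather than a new wrapper object, so metadata does not propagate through arithmetic.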

as_fixed()

Return a copy of the parameter set to fixed.

Returns:

Type Description
Parameter

A copy with fixed=True.

Source code in parax/parameter.py
def as_fixed(self) -> 'Parameter':
    """
    Return a copy of the parameter set to fixed.

    Returns
    -------
    Parameter
        A copy with `fixed=True`.
    """
    return dataclasses.replace(self, fixed=True)

as_free()

Return a copy of the parameter set to free.

Returns:

Type Description
Parameter

A copy with fixed=False.

Source code in parax/parameter.py
def as_free(self) -> 'Parameter':
    """
    Return a copy of the parameter set to free.

    Returns
    -------
    Parameter
        A copy with `fixed=False`.
    """
    return dataclasses.replace(self, fixed=False)
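
`as_fixed` and `as_free` rely on `dataclasses.replace`, which builds a new instance and leaves the original untouched. The sketch below shows the same copy-on-update pattern on a frozen standard-library dataclass; `FrozenParam` is a hypothetical stand-in, not the parax class.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class FrozenParam:
    """Hypothetical immutable stand-in with only the fields needed here."""

    latent_value: float
    fixed: bool = False

    def as_fixed(self):
        # Returns a new object; `self` is never mutated.
        return dataclasses.replace(self, fixed=True)

    def as_free(self):
        return dataclasses.replace(self, fixed=False)


free = FrozenParam(1.5)
frozen = free.as_fixed()
print(free.fixed, frozen.fixed)   # False True
print(frozen.as_free() == free)   # True
```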

copy()

Return a shallow copy.

Returns:

Type Description
Parameter

A copied instance.

Source code in parax/parameter.py
def copy(self):
    """
    Return a shallow copy.

    Returns
    -------
    Parameter
        A copied instance.
    """
    return dataclasses.replace(self)

flattened(separator='_')

Flatten the parameter into a list of scalar Parameters.

If the internal parameter is scalar, the list will contain self. Otherwise, the parameter is split (de-vectorized) if possible.

Parameters:

Name Type Description Default
separator str

Separator used for naming split parameters (e.g., name_0), by default '_'.

'_'

Returns:

Type Description
list of Parameter

The list of individual scalar parameters.

Raises:

Type Description
ValueError

If the list of names does not match the parameter size.

Source code in parax/parameter.py
def flattened(self, separator='_') -> 'list[Parameter]':
    """
    Flatten the parameter into a list of scalar Parameters.

    If the internal parameter is scalar, the list will contain self.
    Otherwise, the parameter is split (de-vectorized) if possible.

    Parameters
    ----------
    separator : str, optional
        Separator used for naming split parameters (e.g., name_0), by default '_'.

    Returns
    -------
    list of Parameter
        The list of individual scalar parameters.

    Raises
    ------
    ValueError
        If the list of names does not match the parameter size.
    """
    if self.latent_value.ndim == 0 and not isinstance(self.name, list):
        return [self]

    unscaled_physical = self.value
    flat_val = jnp.ravel(unscaled_physical)

    if self.distribution is not None:
        if not self.distribution.event_shape:
            dists_split = [self.distribution] * flat_val.size
        else:
            dists_split = split_vectorized_distribution(self.distribution)
    else:
        dists_split = [None] * flat_val.size

    if isinstance(self.name, list):
        if len(self.name) != flat_val.size:
            raise ValueError(f"Length of name list ({len(self.name)}) must match parameter size ({flat_val.size}).")
        flat_names = self.name
    elif self.name is not None:
        flat_names = [f"{self.name}{separator}{i}" for i in range(flat_val.size)]
    else:
        flat_names = [None] * flat_val.size

    return [
        Parameter(
            value=val, 
            fixed=self.fixed, 
            distribution=p, 
            transform=self.transform,
            bounds=self.bounds,
            scale=self.scale, 
            name=n,
            **self.info 
        ) 
        for val, p, n in zip(flat_val, dists_split, flat_names)
    ]
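
The naming rules inside `flattened` can be isolated as a small pure function. `flat_names` below is a hypothetical helper that reproduces only the name-handling branch of the listing above, with `size` standing in for the number of flattened elements.

```python
def flat_names(name, size, separator="_"):
    """Sketch of how `flattened` derives one name per scalar element."""
    if isinstance(name, list):
        # An explicit list must provide exactly one name per element.
        if len(name) != size:
            raise ValueError(
                f"Length of name list ({len(name)}) must match parameter size ({size})."
            )
        return name
    if name is not None:
        # A single string is suffixed with the separator and the element index.
        return [f"{name}{separator}{i}" for i in range(size)]
    # Unnamed parameters stay unnamed after splitting.
    return [None] * size


print(flat_names("w", 3, separator="."))  # ['w.0', 'w.1', 'w.2']
print(flat_names(["a", "b"], 2))          # ['a', 'b']
print(flat_names(None, 2))                # [None, None]
```

One consequence visible in the constructor listing: a parameter created with `n` and a string name already carries a list of `name_i` entries, so `separator` has no effect on it.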

from_json(s) classmethod

Deserialize a parameter from a JSON string.

Parameters:

Name Type Description Default
s str

The JSON string produced by to_json.

required

Returns:

Type Description
Parameter

A reconstructed Parameter instance.

Source code in parax/parameter.py
@classmethod
def from_json(cls, s: str) -> "Parameter":
    """
    Deserialize a parameter from a JSON string.

    Parameters
    ----------
    s : str
        The JSON string produced by `to_json`.

    Returns
    -------
    Parameter
        A reconstructed `Parameter` instance.
    """
    d = json.loads(s)

    raw_value = d.pop("value", None)
    value = deserialize_array(raw_value)
    fixed = d.pop("fixed", False)

    if "distribution" in d:
        d["distribution"] = deserialize_distribution(d["distribution"])

    if "transform" in d:
        d["transform"] = deserialize_transform(d["transform"])

    info_dict = d.pop("info", {})
    d.update(info_dict)

    return cls(value=value, fixed=fixed, **d)

to_json()

Serialize the parameter to a JSON string.

Omits any fields that are None or empty to keep the payload lightweight.

Returns:

Type Description
str

A JSON-formatted string containing the parameter's data.

Source code in parax/parameter.py
def to_json(self) -> str:
    """
    Serialize the parameter to a JSON string.

    Omits any fields that are None or empty to keep the payload lightweight.

    Returns
    -------
    str
        A JSON-formatted string containing the parameter's data.
    """
    d = {
        "fixed": self.fixed,
        "scale": self.scale
    }

    d["value"] = serialize_array(self.value) if self.latent_value is not None else None

    if self.distribution is not None:
        d["distribution"] = serialize_distribution(self.distribution)

    if self.transform is not None:
        d["transform"] = serialize_transform(self.transform)

    if self.bounds is not None:
        d["bounds"] = self.bounds.tolist()

    if self.name is not None:
        d["name"] = self.name

    if self.info: 
        d["info"] = self.info 

    return json.dumps(d, indent=2)
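
The payload shape produced by `to_json` and consumed by `from_json` can be sketched with the standard `json` module alone. `to_payload` and `from_payload` are hypothetical helpers that handle only JSON-native fields; distribution and transform serialization are left out because their helper functions are not shown on this page.

```python
import json


def to_payload(value, fixed=False, scale=1.0, name=None, bounds=None, info=None):
    """Build the JSON string, omitting optional fields that are None or empty."""
    d = {"fixed": fixed, "scale": scale, "value": value}
    if bounds is not None:
        d["bounds"] = list(bounds)
    if name is not None:
        d["name"] = name
    if info:
        d["info"] = info
    return json.dumps(d, indent=2)


def from_payload(s):
    """Recover constructor-style keyword arguments from the JSON string."""
    d = json.loads(s)
    value = d.pop("value", None)
    fixed = d.pop("fixed", False)
    # Custom info entries are merged back into the flat keyword arguments.
    d.update(d.pop("info", {}))
    return {"value": value, "fixed": fixed, **d}


s = to_payload(1.5, name="w", info={"units": "m"})
kwargs = from_payload(s)
print(kwargs)
```

Because the `info` entries are merged into the same dictionary as the core fields, an info key that shares a name with a core field (for example `name`) would overwrite it on load.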

transformed(transform)

Return a copy of this parameter transformed.

This method applies the given transform to the parameter's physical space. It holistically updates the parameter by chaining the new transform with any existing one, transforming the probability distribution, and mapping the bounds. The underlying latent unconstrained value remains unchanged.

Parameters:

Name Type Description Default
transform AbstractBijector

The transform to apply to the parameter's unscaled physical space.

required

Returns:

Type Description
Parameter

A dynamically transformed Parameter object.

Raises:

Type Description
TypeError

If the provided transform is not an instance of AbstractBijector.

Source code in parax/parameter.py
def transformed(self, transform: AbstractBijector) -> 'Parameter':
    """
    Return a copy of this parameter transformed.

    This method applies the given transform to the parameter's physical space. 
    It holistically updates the parameter by chaining the new transform with 
    any existing one, transforming the probability distribution, and mapping 
    the bounds. The underlying latent unconstrained value remains unchanged.

    Parameters
    ----------
    transform : distreqx.bijectors.AbstractBijector
        The transform to apply to the parameter's unscaled physical space.

    Returns
    -------
    Parameter
        A dynamically transformed Parameter object.

    Raises
    ------
    TypeError
        If the provided transform is not an instance of AbstractBijector.
    """
    if not isinstance(transform, AbstractBijector):
        raise TypeError("The provided transformation must be a distreqx AbstractBijector.")
    if self.latent_value is None:
        raise Exception("Cannot transform a parameter with a None latent value")

    # 1. Transform the distribution
    new_dist = self.distribution
    if new_dist is not None:
        new_dist = Transformed(new_dist, transform)

    # 2. Chain the transforms (applied right-to-left: first old, then new)
    old_transform = self.transform
    if old_transform is not None:
        new_transform = Chain([transform, old_transform])
    else:
        new_transform = transform

    # 3. Transform the bounds
    new_bounds = self.bounds
    if new_bounds is not None:
        new_bounds = transform.forward(new_bounds)

    # 4. Update metadata
    if self.metadata is None:
        new_meta = ParameterMetadata(
            distribution=new_dist,
            transform=new_transform,
            bounds=new_bounds
        )
    else:
        new_meta = dataclasses.replace(
            self.metadata, 
            distribution=new_dist,
            transform=new_transform,
            bounds=new_bounds
        )

    # The latent value remains unchanged; the chained transform handles the new physical mapping.
    return dataclasses.replace(self, metadata=new_meta)
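
The chaining order matters: the listing builds `Chain([transform, old_transform])` and notes that it is applied right to left. The sketch below reproduces that order with plain functions; `chain` is a hypothetical helper, and `math.exp` and a scaling lambda stand in for bijector forward maps.

```python
import math


def chain(*fns):
    """Compose right-to-left: chain(f, g)(x) == f(g(x))."""
    def composed(x):
        for fn in reversed(fns):
            x = fn(x)
        return x
    return composed


latent = 0.0                      # stays the same before and after
old = math.exp                    # existing transform: latent -> physical
new = lambda y: 10.0 * y          # transform passed to `transformed`

combined = chain(new, old)        # first old, then new
old_bounds = (0.5, 2.0)
new_bounds = tuple(new(b) for b in old_bounds)  # bounds go through `new` only

print(old(latent))       # 1.0, physical value before
print(combined(latent))  # 10.0, physical value after
print(new_bounds)        # (5.0, 20.0)
```

The latent value is untouched, which is why the method can return a copy that differs only in its metadata.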

with_distribution(distribution)

Return a copy of the parameter with a new distribution.

Parameters:

Name Type Description Default
distribution AbstractDistribution

The distribution to associate with this parameter.

required

Returns:

Type Description
Parameter

A copy of this object with the distribution replaced.

Raises:

Type Description
Exception

If distribution is not a distreqx AbstractDistribution.

Source code in parax/parameter.py
def with_distribution(self, distribution: AbstractDistribution) -> 'Parameter':
    """
    Return a copy of the parameter with a new distribution.

    Parameters
    ----------
    distribution : distreqx.distributions.AbstractDistribution
        The distribution to associate with this parameter.

    Returns
    -------
    Parameter
        A copy of this object with the `distribution` replaced.

    Raises
    ------
    Exception
        If `distribution` is not a distreqx AbstractDistribution.
    """
    if not isinstance(distribution, AbstractDistribution):
        raise Exception('Only distreqx distributions are supported as parameter distributions')

    if self.metadata is None:
        new_meta = ParameterMetadata(distribution=distribution)
    else:
        new_meta = dataclasses.replace(self.metadata, distribution=distribution)

    return dataclasses.replace(self, metadata=new_meta)

with_name(name)

Return a copy of the parameter with a new name.

Parameters:

Name Type Description Default
name str

The new name.

required

Returns:

Type Description
Parameter

A copy of this object with name updated.

Source code in parax/parameter.py
def with_name(self, name: str) -> 'Parameter':
    """
    Return a copy of the parameter with a new name.

    Parameters
    ----------
    name : str
        The new name.

    Returns
    -------
    Parameter
        A copy of this object with `name` updated.
    """
    if self.metadata is None:
        new_meta = ParameterMetadata(name=name)
    else:
        new_meta = dataclasses.replace(self.metadata, name=name)

    return dataclasses.replace(self, metadata=new_meta)

with_transform(transform)

Return a copy of the parameter with a new transform.

Parameters:

Name Type Description Default
transform AbstractBijector

The transform to associate with this parameter.

required

Returns:

Type Description
Parameter

A copy of this object with the transform replaced.

Raises:

Type Description
Exception

If transform is not a distreqx AbstractBijector.

Source code in parax/parameter.py
def with_transform(self, transform: AbstractBijector) -> 'Parameter':
    """
    Return a copy of the parameter with a new transform.

    Parameters
    ----------
    transform : distreqx.bijectors.AbstractBijector
        The transform to associate with this parameter.

    Returns
    -------
    Parameter
        A copy of this object with the `transform` replaced.

    Raises
    ------
    Exception
        If `transform` is not a distreqx AbstractBijector.
    """
    if not isinstance(transform, AbstractBijector):
        raise Exception('Only distreqx bijectors are supported as parameter transforms')

    if self.metadata is None:
        new_meta = ParameterMetadata(transform=transform)
    else:
        new_meta = dataclasses.replace(self.metadata, transform=transform)

    return dataclasses.replace(self, metadata=new_meta)    

with_value(value)

Return a copy of the parameter with a new physical value.

Parameters:

Name Type Description Default
value ndarray

The new unscaled physical value to set. It will be mapped through the transform inverse if one exists.

required

Returns:

Type Description
Parameter

A copy of this object with value updated.

Source code in parax/parameter.py
def with_value(self, value: jnp.ndarray) -> 'Parameter':
    """
    Return a copy of the parameter with a new physical value.

    Parameters
    ----------
    value : jnp.ndarray
        The new unscaled physical value to set. It will be mapped through 
        the transform inverse if one exists.

    Returns
    -------
    Parameter
        A copy of this object with `value` updated.
    """
    latent_value = jnp.asarray(value)
    if self.metadata is not None and self.metadata.transform is not None:
        latent_value = self.metadata.transform.inverse(latent_value)        
    return dataclasses.replace(self, latent_value=latent_value)

parax.ParameterMetadata

Bases: Module

Hidden struct to hold all parameter metadata.

This keeps the core Parameter class lightweight for basic users by compartmentalizing the extended properties that parax interacts with. It also contains an info field to store arbitrary user-defined metadata.

Attributes:

Name Type Description
name str, list, or None

The identifier(s) for the parameter. Must be either a single string or a list matching the shape of the underlying array.

distribution AbstractDistribution or None

The probability distribution associated with the parameter in unscaled physical space.

transform AbstractBijector or None

The transform used to map from the latent space to the unscaled physical space.

bounds ndarray or None

The boundaries of the parameter in unscaled physical space. Can be used as hard constraints for bounded optimizers.

scale float

A scalar multiplier applied to the unscaled physical value to convert it to a JAX array to be used in calculations. Defaults to 1.0.

info dict

A dictionary for storing additional, arbitrary user-defined metadata. Marked as static.

parax.ParameterGroup(param_names, name=None, distribution=None, transform=None, info=field(default_factory=dict, static=True)) dataclass

A metadata class that groups a set of named flat parameters and defines any joint relationships, distributions, or transforms between them.

Attributes:

Name Type Description
param_names list of str

The names of the parameters included in this group.

name (str or None, optional)

An optional identifier for the group itself (e.g., 'covariance_matrix').

distribution (AbstractDistribution or None, optional)

An optional joint distribution over the grouped parameters.

transform (AbstractBijector or None, optional)

An optional joint transform applied to the grouped parameters. This is provided for future compatibility and is not yet used.

info dict

Arbitrary user-defined metadata associated with the group. Marked as static.

num_params property

Get the number of flattened parameters in the group.

Returns:

Type Description
int

The count of names in param_names.

transformed(transform)

Return a copy of this parameter group transformed by an additional joint bijector.

This method chains the new bijector with any existing group-level bijector, applying the transformations sequentially.

Parameters:

Name Type Description Default
transform AbstractBijector

The transform to apply to the group.

required

Returns:

Type Description
ParameterGroup

A dynamically transformed ParameterGroup object.

Raises:

Type Description
TypeError

If the provided bijector is not a distreqx AbstractBijector.

Source code in parax/parameter_group.py
def transformed(self, transform: AbstractBijector) -> 'ParameterGroup':
    """
    Return a copy of this parameter group transformed by an additional joint bijector.

    This method chains the new bijector with any existing group-level bijector,
    applying the transformations sequentially.

    Parameters
    ----------
    transform : distreqx.bijectors.AbstractBijector
        The transform to apply to the group.

    Returns
    -------
    ParameterGroup
        A dynamically transformed ParameterGroup object.

    Raises
    ------
    TypeError
        If the provided bijector is not a distreqx AbstractBijector.
    """
    if not isinstance(transform, AbstractBijector):
        raise TypeError("The provided transformation must be a distreqx AbstractBijector.")

    new_transform = self.transform
    if new_transform is not None:
        new_transform = Chain([transform, new_transform])
    else:
        new_transform = transform

    return dataclasses.replace(self, transform=new_transform)

with_distribution(distribution)

Return a copy of the parameter group with a new joint distribution.

Parameters:

Name Type Description Default
distribution AbstractDistribution

The joint distribution to associate with this parameter group.

required

Returns:

Type Description
ParameterGroup

A copy of this object with the distribution replaced.

Raises:

Type Description
TypeError

If distribution is not a distreqx AbstractDistribution.

Source code in parax/parameter_group.py
def with_distribution(self, distribution: AbstractDistribution) -> 'ParameterGroup':
    """
    Return a copy of the parameter group with a new joint distribution.

    Parameters
    ----------
    distribution : distreqx.distributions.AbstractDistribution
        The joint distribution to associate with this parameter group.

    Returns
    -------
    ParameterGroup
        A copy of this object with the `distribution` replaced.

    Raises
    ------
    TypeError
        If `distribution` is not a distreqx AbstractDistribution.
    """
    if not isinstance(distribution, AbstractDistribution):
        raise TypeError('Only distreqx distributions are supported as parameter distributions')

    return dataclasses.replace(self, distribution=distribution)

parax.partition(pytree, include_fixed=False, include_arrays=False, param_objects=False)

Partitions an arbitrary PyTree into (dynamic, static) halves.

By default, this acts as a "strict" parameter partitioner: ONLY non-fixed parax.Parameter objects are routed to the dynamic tree. Raw JAX arrays are treated as static data unless explicitly requested.

Parameters:

Name Type Description Default
pytree T

The PyTree to partition.

required
include_fixed bool

If True, parax.Parameter objects with fixed=True are also routed to the dynamic tree.

False
include_arrays bool

If True, standard JAX floating-point arrays (not wrapped in a parax.Parameter) are ALSO routed to the dynamic tree.

False
param_objects bool

If True, the entire parax.Parameter object is routed to the dynamic tree. If False, ONLY the underlying .latent_value array is routed to the dynamic tree.

False

Returns:

Type Description
tuple of T

A tuple containing (dynamic, static) PyTrees.

Source code in parax/tree.py
def partition(
    pytree: T, 
    include_fixed: bool = False, 
    include_arrays: bool = False,
    param_objects: bool = False,
) -> tuple[T, T]:
    """
    Partitions an arbitrary PyTree into (dynamic, static) halves.

    By default, this acts as a "strict" parameter partitioner: ONLY non-fixed 
    [`~parax.Parameter`][] objects are routed to the dynamic tree. Raw JAX arrays are 
    treated as static data unless explicitly requested.

    Parameters
    ----------
    pytree : T
        The PyTree to partition.
    include_fixed : bool, default=False
        If True, includes [`~parax.Parameter`][] objects where `fixed=True`.
    include_arrays : bool, default=False
        If True, standard JAX floating-point arrays (not wrapped in a 
        [`~parax.Parameter`][]) are ALSO routed to the dynamic tree.
    param_objects : bool, default=False
        If True, the entire [`~parax.Parameter`][] object is routed to the dynamic tree. 
        If False, ONLY the underlying `.latent_value` array is routed to the dynamic tree.

    Returns
    -------
    tuple of T
        A tuple containing `(dynamic, static)` PyTrees.
    """

    def build_mask(node):
        # 1. Parameter Logic
        if is_valid_param(node):
            if not include_fixed and getattr(node, "fixed", False):
                return False 

            if param_objects:
                return True
            else:
                false_param = jax.tree_util.tree_map(lambda _: False, node)
                return eqx.tree_at(lambda p: p.latent_value, false_param, True)

        # 2. Raw Array Logic (The Escape Hatch)
        if include_arrays and eqx.is_array(node):
            # Only treat floating point arrays as dynamic (standard JAX/Equinox behavior)
            return jax.numpy.issubdtype(node.dtype, jax.numpy.inexact)

        # 3. Everything else is static
        return False

    # Build the filter spec
    filter_spec = jax.tree_util.tree_map(build_mask, pytree, is_leaf=is_valid_param)

    # Preserve Parameter objects if requested
    leaf_fn = is_valid_param if param_objects else None

    return eqx.partition(pytree, filter_spec, is_leaf=leaf_fn)
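
The mask-then-split idea behind partition can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: it covers only the include_arrays rule (inexact dtypes are dynamic), the tree and the helper is_dynamic are hypothetical, and the None placeholders mimic how Equinox's partition fills the positions that belong to the other half.

```python
import jax
import jax.numpy as jnp

# Hypothetical PyTree: one float array, one integer array, one non-array leaf.
tree = {"weights": jnp.ones(3), "steps": jnp.arange(3), "label": "demo"}

def is_dynamic(leaf):
    # Same dtype test as the include_arrays branch: only inexact
    # (floating-point or complex) arrays count as dynamic.
    return isinstance(leaf, jax.Array) and bool(jnp.issubdtype(leaf.dtype, jnp.inexact))

# Boolean mask with the same structure as the tree.
mask = jax.tree_util.tree_map(is_dynamic, tree)

# Split by the mask, leaving None where a leaf belongs to the other half.
dynamic = jax.tree_util.tree_map(lambda leaf, keep: leaf if keep else None, tree, mask)
static = jax.tree_util.tree_map(lambda leaf, keep: None if keep else leaf, tree, mask)
```

Here only "weights" lands in the dynamic half; the integer array and the string stay static, which matches the default treatment of non-floating data described above.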

parax.Module

Bases: Module

An extension of an Equinox Module.

This class extends an Equinox Module with additional helpful features and methods.

One feature is the ability to inspect and modify parameters by string name, based on their module path. This helps when modifying deep, hierarchical modules through unique identifiers.

Another feature is that attributes annotated with the Parameter type are automatically given parameter converters, so they remain parameters after construction (e.g. when initialized with a float).

Usage
Methods & Properties Summary

Introspection Properties

Module Inspection & Manipulation

Method Description
children Returns the immediate submodules.
submodules Returns all nested submodules (depth-first).
sampled Return a new module with parameters drawn from this module's distribution.

Parameter Inspection

Method Description
named_params Named module parameter objects as a dict.
named_param_values Named module parameter values as a dict of jax arrays.
param_names Module parameter names as a list.
param A single module parameter object by name.
params Module parameters as a list.
param_value A single module parameter value by name.
param_values Module parameter values as a list of jax arrays.
named_flat_params Named flattened module parameter objects as a dict.
named_flat_param_values Named flattened module parameter values as a dict.
flat_param_names Flattened parameter names as a list.
flat_params Flattened parameters as a list.
flat_param_values Flattened module parameter values as a flat array.
param_groups Return all parameter groups relevant to this module.

Parameter Manipulation

Method Description
with_params Return a module with parameters updated.
with_mapped_params Apply a map function to parameters.
with_transformed_params Apply a map function to parameters.
with_fixed_params Return a module with specified parameters fixed.
with_free_params Return a module with specified parameters free.
with_free_params_only Return a module with ONLY the specified parameters free.
with_all_params_fixed Return a module with all parameters fixed.
with_all_params_free Return a module with all parameters free.

Parameter Group Manipulation

Method Description
with_param_groups Return a module with parameter groups appended.
with_demoted_param_groups Recursively demote parameter groups to deepest submodule.
with_no_param_groups Return a module with all parameter groups removed.

Distribution Manipulation

Method Description
with_mapped_distributions Apply a map function to the parameter distributions.
with_uniform_distributions Return a module with uniform distributions set.

Field & Module Manipulation

Method Description
with_defaults Return this module type with default initialization args.
with_modules Combines this module with free parameters in other modules.
with_fields Return a copy with dataclass-style field replacements.
with_name Return a copy of this module with a different name.
with_submodule_fields Dataclass-style field replacements on a nested sub-module.
with_free_submodules Free all parameters in the given submodules.
with_free_submodules_only Returns a module with ONLY the specified submodules freed.
with_fixed_submodules Fix all parameters in the given submodules.

Function Tools

Method Description
func_jacobian Calculate the Jacobian of a function with respect to parameters.
func_sensitivity Calculate the sensitivity of a function with respect to parameters.
func_samples Evaluate a function over parameter samples.

Attributes:

Name Type Description
name str or None

An optional name for the module instance.

num_flat_params property

Number of free, flattened parameters.

Returns:

Type Description
int

num_params property

Number of free parameters.

Returns:

Type Description
int

__init_subclass__(transparent=False, **kwargs)

Customize subclass construction.

Source code in parax/module.py
def __init_subclass__(cls, transparent: bool = False, **kwargs):
    """Customize subclass construction."""        
    super().__init_subclass__(**kwargs)
    cls._transparent = transparent

children()

Returns the immediate submodules.

Returns:

Type Description
list[Module]
Source code in parax/module.py
def children(self) -> list['Module']:
    """Returns the immediate submodules.

    Returns
    -------
    list[Module]
    """
    return [node for node in eqx.tree_flatten_one_level(self)[0] if isinstance(node, Module)]

copy()

Returns a deepcopy of self.

Returns:

Type Description
Module
Source code in parax/module.py
def copy(self: Self) -> Self:
    """Returns a deepcopy of self.

    Returns
    -------
    Module
    """        
    return deepcopy(self)   

flat_param_names(*args, **kwargs)

Return flattened parameter names as a list.

See parax.Module.named_flat_params.

Source code in parax/module.py
def flat_param_names(self, *args, **kwargs) -> list[str]:
    """
    Return flattened parameter names as a list.

    See [`parax.Module.named_flat_params`][].
    """
    return list(self.named_flat_params(*args, **kwargs).keys())    

flat_param_values(*args, **kwargs)

Return flattened module parameter values as a flat JAX array.

See parax.Module.named_flat_param_values.

Source code in parax/module.py
def flat_param_values(self, *args, **kwargs) -> jnp.ndarray:
    """
    Return flattened module parameter values as a flat JAX array.

    See [`parax.Module.named_flat_param_values`][].
    """
    return jnp.array(list(self.named_flat_param_values(*args, **kwargs).values())).reshape(-1)

flat_params(*args, **kwargs)

Return flattened parameters as a list.

See parax.Module.named_flat_params.

Source code in parax/module.py
def flat_params(self, *args, **kwargs) -> list[Parameter]:
    """
    Return flattened parameters as a list.

    See [`parax.Module.named_flat_params`][].
    """
    return list(self.named_flat_params(*args, **kwargs).values())

func_jacobian(func, args)

Calculate the Jacobian of an arbitrary function with respect to free parameters.

This uses forward-mode automatic differentiation to compute the gradients of the provided function with respect to each free parameter in the module.

Parameters:

Name Type Description Default
func Callable[[Module], ndarray]

Function to differentiate. Must take a Module and args and return a jnp.ndarray of any shape.

required
args Any

The args to pass to func.

required

Returns:

Type Description
dict[str, ndarray]

A dictionary mapping flat parameter names to their gradient arrays. Each array has the same shape as the output of func.

Source code in parax/module.py
@eqx.filter_jit
def func_jacobian(
    self: Self, 
    func: Callable[['Module'], jnp.ndarray], 
    args: Any
) -> dict[str, jnp.ndarray]:
    """Calculate the Jacobian of an arbitrary function with respect to free parameters.

    This uses forward-mode automatic differentiation to compute the gradients 
    of the provided function with respect to each free parameter in the module.

    Parameters
    ----------
    func : Callable[[Module], jnp.ndarray]
        Function to differentiate. Must take a Module and args
        and return a jnp.ndarray of any shape.
    args : Any
        The args to pass to `func`.

    Returns
    -------
    dict[str, jnp.ndarray]
        A dictionary mapping flat parameter names to their gradient 
        arrays. Each array has the same shape as the output of `func`.
    """
    def func_from_flat(flat_params_array: jnp.ndarray) -> jnp.ndarray:
        sampled_module = self.with_params(flat_params_array)
        return func(sampled_module, args)

    jac_array = jax.jacfwd(func_from_flat)(self.flat_param_values())
    jac_moved = jnp.moveaxis(jac_array, -1, 0)
    param_names = self.flat_param_names()

    return {name: jac_moved[i] for i, name in enumerate(param_names)}
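
The reshaping step at the end of func_jacobian can be illustrated with plain JAX. The sketch below assumes a hypothetical two-parameter model and hypothetical names; the function model stands in for func(self.with_params(flat_params), args).

```python
import jax
import jax.numpy as jnp

# Hypothetical flat parameters and their names, in matching order.
param_names = ["gain", "offset"]
theta = jnp.array([2.0, 1.0])
x = jnp.array([0.0, 1.0, 2.0])

def model(flat_params):
    # Stand-in for func(self.with_params(flat_params), args).
    return flat_params[0] * x + flat_params[1]

# jacfwd puts the output axes first and the parameter axis last: shape (3, 2).
jac_array = jax.jacfwd(model)(theta)

# Move the parameter axis to the front so each row is one parameter: shape (2, 3).
jac_moved = jnp.moveaxis(jac_array, -1, 0)

jacobian = {name: jac_moved[i] for i, name in enumerate(param_names)}
```

Each dictionary entry has the shape of the model output, as the Returns section states. Forward mode costs roughly one pass per parameter, which tends to suit functions with few parameters and many outputs.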

func_samples(func, args, *, key, num_samples=1000)

Evaluates an arbitrary function over samples drawn from the module's distribution.

Parameters:

Name Type Description Default
func Callable[[Module], ndarray]

A function that takes a Module instance and args and returns a JAX array.

required
args Any

The args to pass to func.

required
key Array

JAX random key for sampling.

required
num_samples int

Number of modules to sample from the joint distribution.

1000

Returns:

Type Description
ndarray

The function evaluated over all samples. Shape will be (num_samples, *func_output_shape).

Source code in parax/module.py
@eqx.filter_jit
def func_samples(
    self, 
    func: Callable[['Module'], jnp.ndarray], 
    args: Any,
    *,
    key: jax.Array, 
    num_samples: int = 1000
) -> jnp.ndarray:
    """
    Evaluates an arbitrary function over samples drawn from the 
    module's distribution.

    Parameters
    ----------
    func : Callable[[Module], jnp.ndarray]
        A function that takes a Module instance and args and returns a JAX array.
    args : Any
        The args to pass to `func`.            
    key : jax.Array
        JAX random key for sampling.
    num_samples : int, default=1000
        Number of modules to sample from the joint distribution.

    Returns
    -------
    jnp.ndarray
        The function evaluated over all samples. Shape will be 
        (num_samples, *func_output_shape).
    """
    dist = self.flat_distribution()
    flat_param_samples = dist.sample(key, sample_shape=(num_samples,))

    def evaluate_single(flat_params_array):
        sampled_module = self.with_params(flat_params_array)
        return func(sampled_module, args)

    return jax.vmap(evaluate_single)(flat_param_samples)  
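
The batched evaluation can be sketched with plain JAX. The normal sampling below is an assumption standing in for self.flat_distribution().sample(...), and evaluate_single is a hypothetical model rather than a call into a module.

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.0, 1.0, 2.0])

def evaluate_single(flat_params):
    # Stand-in for func(self.with_params(flat_params), args).
    return flat_params[0] * x + flat_params[1]

key = jax.random.PRNGKey(0)
num_samples = 5

# Assumed sampling step: independent normals centred on (2.0, 1.0).
flat_param_samples = jnp.array([2.0, 1.0]) + 0.1 * jax.random.normal(key, (num_samples, 2))

# vmap maps over the leading sample axis and stacks the results.
outputs = jax.vmap(evaluate_single)(flat_param_samples)  # shape (5, 3)
```

The leading axis of the result indexes samples and the remaining axes are the function's own output shape, matching (num_samples, *func_output_shape).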

func_sensitivity(func, args, kind='relative', norm=None)

Calculate the sensitivity of an arbitrary function w.r.t. parameters.

Supported kinds:

- 'relative': (dy/dtheta) * (theta/y). Fractional change in output per fractional change in parameter. Blows up if y is zero.
- 'semi-relative': (dy/dtheta) * theta. Change in output per fractional change in parameter. Stable if y is zero.
- 'absolute': (dy/dtheta). Raw gradient.

Parameters:

Name Type Description Default
func Callable[[Module], ndarray]

Function to evaluate.

required
args Any

The args to pass to func.

required
kind str

The type of sensitivity to calculate ('relative', 'semi-relative', 'absolute').

'relative'
norm int | str | None

If provided, aggregates the parameter sensitivities into a single scalar metric using the specified norm (e.g., 2 for L2 norm, jnp.inf for max norm).

None

Returns:

Type Description
dict[str, ndarray] | ndarray

If norm is None, returns a dictionary mapping flat parameter names to sensitivity arrays. If norm is specified, returns a 0D scalar jax array representing the global sensitivity metric.

Source code in parax/module.py
@eqx.filter_jit
def func_sensitivity(
    self: Self, 
    func: Callable[['Module'], jnp.ndarray], 
    args: Any,
    kind: str = 'relative',
    norm: int | str | None = None
) -> dict[str, jnp.ndarray] | jnp.ndarray:
    r"""Calculate the sensitivity of an arbitrary function w.r.t parameters.

    Supported kinds:
    - 'relative': (dy/dtheta) * (theta/y). Fractional change in output per 
      fractional change in parameter. Blows up if y is zero.
    - 'semi-relative': (dy/dtheta) * theta. Change in output per 
      fractional change in parameter. Stable if y is zero.
    - 'absolute': (dy/dtheta). Raw gradient.

    Parameters
    ----------
    func : Callable[[Module], jnp.ndarray]
        Function to evaluate.
    args : Any
        The args to pass to `func`.
    kind : str, default='relative'
        The type of sensitivity to calculate ('relative', 'semi-relative', 'absolute').
    norm : int | str | None, default=None
        If provided, aggregates the parameter sensitivities into a single scalar 
        metric using the specified norm (e.g., 2 for L2 norm, jnp.inf for max norm).

    Returns
    -------
    dict[str, jnp.ndarray] | jnp.ndarray
        If `norm` is None, returns a dictionary mapping flat parameter names 
        to sensitivity arrays.
        If `norm` is specified, returns a 0D scalar jax array representing 
        the global sensitivity metric.
    """
    def func_from_flat(flat_params_array: jnp.ndarray) -> jnp.ndarray:
        sampled_module = self.with_params(flat_params_array)
        return func(sampled_module, args)

    theta = self.flat_param_values()
    jac_array = jax.jacfwd(func_from_flat)(theta)

    if kind == 'absolute':
        sens_array = jac_array

    elif kind == 'semi-relative':
        sens_array = jac_array * theta

    elif kind == 'relative':
        y_nom = func(self, args)
        y_safe = jnp.where(y_nom == 0, 1e-15, y_nom)
        sens_array = jac_array * (theta / y_safe[..., None])

    else:
        raise ValueError(f"Unknown sensitivity kind: '{kind}'. "
                         f"Supported: 'relative', 'semi-relative', 'absolute'") 

    if norm is not None:
        return jnp.linalg.norm(sens_array, ord=norm)

    sens_moved = jnp.moveaxis(sens_array, -1, 0)
    param_names = self.flat_param_names()

    return {name: sens_moved[i] for i, name in enumerate(param_names)}
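
The three sensitivity kinds can be compared on a small power-law model. This is a sketch under assumptions: the model y = a * x**b and its parameter values are hypothetical, and the zero-guard on y used in the source is omitted because y is nonzero here.

```python
import jax
import jax.numpy as jnp

theta = jnp.array([3.0, 2.0])        # hypothetical parameters (a, b)
x = jnp.array([1.0, 2.0, 4.0])

def model(flat_params):
    # y = a * x**b
    return flat_params[0] * x ** flat_params[1]

jac = jax.jacfwd(model)(theta)       # dy/dtheta, shape (3, 2)
y = model(theta)                     # nominal output, shape (3,)

absolute = jac                            # raw gradient
semi_relative = jac * theta               # (dy/dtheta) * theta
relative = jac * (theta / y[..., None])   # (dy/dtheta) * (theta / y)
```

For this model the relative sensitivity to a is 1 at every x, and the relative sensitivity to b is b * ln(x), which is the usual log-log derivative. The trailing None index on y broadcasts the output shape against the parameter axis.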

iter_params(param_filter=None, *, include_fixed=False, flatten=False, submodules=None)

Iterate over (name, Parameter) pairs in internal order.

Source code in parax/module.py
def iter_params(
    self,
    param_filter: str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] = None,
    *,
    include_fixed: bool = False,
    flatten: bool = False,
    submodules: 'Module | Sequence[Module] | str | Sequence[str] | None' = None,
) -> Iterator[tuple[str, Parameter]]:
    """Iterate over (name, Parameter) pairs in internal order."""

    # 1. Direct Flattening
    path_and_nodes, _ = jax.tree.flatten_with_path(self, is_leaf=is_valid_param)

    # 2. Pre-process submodule IDs outside the loop (Fast JAX C++ traversal)
    allowed_param_ids = None
    if submodules is not None:
        if isinstance(submodules, (Module, str)):
            submodules = [submodules]

        resolved_submodules = [getattr(self, sm) if isinstance(sm, str) else sm for sm in submodules]
        if not isinstance(resolved_submodules[0], Module):
            raise Exception(f"Got unknown type when expecting a module or string. Type was: {type(resolved_submodules[0])}")

        allowed_param_ids = set()
        for sm in resolved_submodules:
            sm_nodes, _ = jax.tree.flatten(sm, is_leaf=is_valid_param)
            for node in sm_nodes:
                if is_valid_param(node) and (include_fixed or not getattr(node, "fixed", False)):
                    allowed_param_ids.add(id(node))

    # 3. Pre-process filters into O(1) lookups
    filter_is_seq_str = False
    filter_is_seq_param = False
    filter_is_callable = False
    filter_ids = None

    if param_filter is not None:
        if isinstance(param_filter, str):
            param_filter = {param_filter} 
            filter_is_seq_str = True
        elif isinstance(param_filter, Parameter):
            filter_ids = {id(param_filter)}
            filter_is_seq_param = True
        elif isinstance(param_filter, Sequence):
            if len(param_filter) > 0:
                if isinstance(param_filter[0], str):
                    param_filter = set(param_filter) 
                    filter_is_seq_str = True
                elif isinstance(param_filter[0], Parameter):
                    filter_ids = {id(p) for p in param_filter}
                    filter_is_seq_param = True
        elif isinstance(param_filter, Callable):
            filter_is_callable = True
        else:
            raise Exception(f"Unknown filter type passed for parameters: {param_filter}")

    # 4. The Single Lazy Pass
    for path, param in path_and_nodes:
        if not is_valid_param(param):
            continue
        if not include_fixed and getattr(param, "fixed", False):
            continue

        if allowed_param_ids is not None and id(param) not in allowed_param_ids:
            continue

        if filter_is_seq_param and id(param) not in filter_ids:
            continue

        name = self.path_to_param_name(path)

        if filter_is_seq_str and name not in param_filter:
            continue
        if filter_is_callable and not param_filter(name):
            continue

        # 5. Flattening & Yielding
        if flatten and (param.size > 1 or isinstance(param.name, list)):
            flattened_params = param.flattened(separator=self._separator)
            for i, subparam in enumerate(flattened_params):
                suffix = subparam.name if subparam.name is not None else str(i)
                yield f"{name}{self._separator}{suffix}", subparam
        else:
            yield name, param
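
The name-based filter handling can be sketched in plain Python. The helper make_name_filter and the parameter names are hypothetical; the sketch covers only string, sequence-of-string, and callable filters, whereas the source also accepts Parameter objects and matches those by id().

```python
from collections.abc import Sequence

def make_name_filter(param_filter):
    """Return a predicate on parameter names (hypothetical helper)."""
    if param_filter is None:
        return lambda name: True               # no filter: keep every name
    if isinstance(param_filter, str):          # check str first: a str is also a Sequence
        allowed = {param_filter}
        return lambda name: name in allowed
    if isinstance(param_filter, Sequence):
        allowed = set(param_filter)            # set membership is O(1) on average
        return lambda name: name in allowed
    if callable(param_filter):
        return param_filter
    raise TypeError(f"Unknown filter type: {type(param_filter)!r}")

names = ["amp", "decay", "sub.gain"]           # hypothetical parameter names
everything = [n for n in names if make_name_filter(None)(n)]
only_amp = [n for n in names if make_name_filter("amp")(n)]
two = [n for n in names if make_name_filter(["amp", "sub.gain"])(n)]
nested = [n for n in names if make_name_filter(lambda name: name.startswith("sub."))(n)]
```

Normalizing the filter once, before the loop, keeps the per-parameter check cheap, which is the point of the pre-processing step in the source.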

merged(modules)

Merge this module with free parameters and parameter groups in other modules.

This is useful to combine separate modules obtained from fitting the same initial module with different free parameters.

Parameters:

Name Type Description Default
modules Module or Sequence[Module]

The other modules to combine this module with.

required

Returns:

Type Description
Module
Source code in parax/module.py
def merged(self: Self, modules: Self | Sequence[Self]) -> Self:
    """Merge this module with free parameters and parameter groups
    in other modules.

    This is useful to combine separate modules obtained from fitting
    the same initial module with different free parameters.

    Parameters
    ----------
    modules : Module or Sequence[Module]
        The other modules to combine this module with.

    Returns
    -------
    Module
    """  
    if not isinstance(modules, Sequence):
        modules = [modules]

    combined = self
    for other in modules:
        combined = combined.with_params(other.named_params())
        combined = combined.with_param_groups(other.param_groups(explicit_only=True))
    return combined    
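
The merge semantics can be pictured with plain dictionaries. This is an analogy rather than the module API: each fit is a hypothetical pair of (all values, names of free parameters), and only the free entries are copied, mirroring the fact that named_params() excludes fixed parameters by default.

```python
def merge_free(base, fits):
    """Copy each fit's free parameter values onto a copy of base, in order."""
    merged = dict(base)
    for values, free in fits:
        merged.update({name: values[name] for name in free})
    return merged

# Hypothetical starting values and two fits with different free parameters.
base = {"amp": 1.0, "decay": 0.5, "offset": 0.0}
fit_a = ({"amp": 1.3, "decay": 0.5, "offset": 0.0}, {"amp"})
fit_b = ({"amp": 1.0, "decay": 0.7, "offset": 0.0}, {"decay"})

result = merge_free(base, [fit_a, fit_b])
```

Because the updates are applied in sequence, a parameter that is free in more than one module takes its value from the last module in the list.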

named_flat_param_values(scaled=False, return_floats=False, **kwargs)

Named flattened module parameter values as a dict of jax arrays.

See parax.Module.named_flat_params.

Parameters:

Name Type Description Default
scaled bool

Whether or not to scale the returned values by the parameter scales.

False
**kwargs

Additional keyword arguments, as in parax.Module.named_params.

{}

Returns:

Type Description
dict[str, ndarray]
Source code in parax/module.py
def named_flat_param_values(self, scaled=False, return_floats=False, **kwargs) -> dict[str, jnp.ndarray]:
    """Named flattened module parameter values as a dict of jax arrays.

    See [`parax.Module.named_flat_params`][].

    Parameters
    ----------
    scaled : bool, default=False
        Whether or not to scale the returned values by the parameter scales.
    **kwargs
        Additional keyword arguments, as in [`parax.Module.named_params`][].

    Returns
    -------
    dict[str, jnp.ndarray]
    """     
    if scaled:
        retval = {n: jnp.array(p) for n, p in (self.iter_params(flatten=True, **kwargs))}
    else:
        retval = {n: p.latent_value for n, p in (self.iter_params(flatten=True, **kwargs))}

    if return_floats:
        import numpy as np
        retval = {k: float(np.array(v)) for k, v in retval.items()}
    return retval

named_flat_params(include_fixed=False, submodules=None)

Named flattened module parameters as a dict.

Flat parameters are a de-vectorized version of the internal parameters of the module. The returned parameter objects therefore are not necessarily equal to the internal module objects.

Keys are fully-qualified parameter names with de-vectorized suffixes added. The order matches the internal flattened array order.

Parameters:

Name Type Description Default
include_fixed bool

Include fixed parameters.

False
submodules Module | Sequence[Module] | str | Sequence[str] | None

Restrict to parameters used by the given submodule(s). If strings are provided, getattr(self, name) is used.

None

Returns:

Type Description
dict[str, Parameter]
Source code in parax/module.py
def named_flat_params(self, include_fixed=False, submodules: 'Module' | Sequence['Module'] | str | Sequence[str] | None = None) -> dict[str, Parameter]:
    """Named flattened module parameters as a dict.

    Flat parameters are a de-vectorized version of
    the internal parameters of the module. The returned
    parameter objects therefore are not necessarily
    equal to the internal module objects.

    Keys are fully-qualified parameter names with de-vectorized suffixes added.
    The order matches the internal flattened array order.

    Parameters
    ----------
    include_fixed : bool, default=False
        Include fixed parameters.
    submodules : Module | Sequence[Module] | str | Sequence[str] | None, optional
        Restrict to parameters used by the given submodule(s). If strings are
        provided, ``getattr(self, name)`` is used.

    Returns
    -------
    dict[str, Parameter]
    """
    return dict(self.iter_params(flatten=True, include_fixed=include_fixed, submodules=submodules))
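
The de-vectorized naming can be sketched in plain Python. The helper flat_items is hypothetical, the "." separator is an assumption (the source uses self._separator, whose value is not shown here), and plain lists stand in for Parameter objects.

```python
def flat_items(name, values, element_names=None, separator="."):
    """Yield (flat_name, value) pairs for one possibly vectorized parameter."""
    if len(values) > 1 or element_names is not None:
        for i, value in enumerate(values):
            # Use the element's own name when given, otherwise its index.
            suffix = element_names[i] if element_names is not None else str(i)
            yield f"{name}{separator}{suffix}", value
    else:
        # Scalar parameters keep their plain name.
        yield name, values[0]

indexed = dict(flat_items("gain", [1.0, 2.0, 3.0]))
named = dict(flat_items("pole", [0.1, 0.2], element_names=["re", "im"]))
scalar = dict(flat_items("offset", [5.0]))
```

Index suffixes are the fallback; when the elements of a vectorized parameter carry their own names, those names are used instead.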

named_param_values(scaled=False, **kwargs)

Named module parameter values as a dict of jax arrays.

See parax.Module.named_params.

Parameters:

Name Type Description Default
scaled bool

Whether or not to scale the returned values by the parameter scales.

False
**kwargs

Additional keyword arguments, as in parax.Module.named_params.

{}

Returns:

Type Description
dict[str, ndarray]
Source code in parax/module.py
def named_param_values(self, scaled=False, **kwargs) -> dict[str, jnp.ndarray]:
    """Named module parameter values as a dict of jax arrays.

    See [`parax.Module.named_params`][].

    Parameters
    ----------
    scaled : bool, default=False
        Whether or not to scale the returned values by the parameter scales.
    **kwargs
        Additional keyword arguments, as in [`parax.Module.named_params`][].

    Returns
    -------
    dict[str, jnp.ndarray]
    """     
    if scaled:
        return {n: jnp.array(p) for n, p in (self.iter_params(**kwargs))}
    else:
        return {n: p.latent_value for n, p in (self.iter_params(**kwargs))}    

named_params(param_filter=None, *, include_fixed=False, submodules=None)

Named module parameters as a dict.

Keys are fully-qualified parameter names. The order matches the internal flattened array order.

Parameters:

Name Type Description Default
param_filter str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool]

A filter indicating which parameters to return. For the default case, all parameters are returned.

None
include_fixed bool

Include fixed parameters.

False
submodules Module | Sequence[Module] | str | Sequence[str] | None

Restrict to parameters used by the given submodule(s). If strings are provided, getattr(self, name) is used.

None

Returns:

Type Description
dict[str, Parameter]
Source code in parax/module.py
def named_params(self, param_filter: str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] = None, *, include_fixed=False, submodules: 'Module' | Sequence['Module'] | str | Sequence[str] | None = None) -> dict[str, Parameter]:
    """Named module parameters as a dict.

    Keys are fully-qualified parameter names.
    The order matches the internal flattened array order.

    Parameters
    ----------
    param_filter : str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool], default=None
        A filter indicating which parameters to return. For the default case, all parameters are returned.
    include_fixed : bool, default=False
        Include fixed parameters.
    submodules : Module | Sequence[Module] | str | Sequence[str] | None, optional
        Restrict to parameters used by the given submodule(s). If strings are
        provided, ``getattr(self, name)`` is used.

    Returns
    -------
    dict[str, Parameter]
    """
    return dict(self.iter_params(param_filter=param_filter, include_fixed=include_fixed, submodules=submodules))

param(name, *args, **kwargs)

Return a single module parameter by name.

See parax.Module.named_params.

Source code in parax/module.py
572
573
574
575
576
577
578
def param(self, name: str, *args, **kwargs) -> Parameter:
    """
    Return a single module parameter by name.

    See [`parax.Module.named_params`][].
    """
    return self.named_params(*args, **kwargs)[name]

param_groups(include_fixed=False, explicit_only=False)

Return all parameter groups relevant to this module, including submodules.

This function recursively traverses submodules to collect their parameter groups, adjusting parameter names to match the current module's scope.

Priority is given to groups defined in the parent module. If a parameter is grouped explicitly in self._param_groups, it will be removed from any groups returned by submodules.

Parameters:

Name Type Description Default
include_fixed bool

Include groups involving fixed parameters.

False

Returns:

Type Description
list[ParameterGroup]
Source code in parax/module.py
def param_groups(self, include_fixed=False, explicit_only=False) -> list[ParameterGroup]:
    """Return all parameter groups relevant to this module, including submodules.

    This function recursively traverses submodules to collect their parameter groups,
    adjusting parameter names to match the current module's scope.

    Priority is given to groups defined in the parent module. If a parameter is 
    grouped explicitly in `self._param_groups`, it will be removed from any 
    groups returned by submodules.

    Parameters
    ----------
    include_fixed : bool, default=False
        Include groups involving fixed parameters.

    Returns
    -------
    list[ParameterGroup]
    """
    if explicit_only:
        return deepcopy(self._param_groups)

    all_valid_params = self.named_flat_params(include_fixed=include_fixed)
    valid_param_names = set(all_valid_params.keys())

    groups = []
    for group in self._param_groups:
        if not set(group.param_names).isdisjoint(valid_param_names):
            groups.append(deepcopy(group))

    path_and_nodes, _ = jax.tree_util.tree_flatten_with_path(
        self, 
        is_leaf=lambda x: isinstance(x, Module) and x is not self
    )

    for path, node in path_and_nodes:
        if isinstance(node, Module) and node is not self:
            relative_name = self.path_to_param_name(path)
            prefix = f"{relative_name}{self._separator}" if relative_name else ""
            sub_groups = node.param_groups(include_fixed=include_fixed)

            for sub_group in sub_groups:
                new_names = [prefix + name for name in sub_group.param_names]
                lifted_group = dataclasses.replace(sub_group, param_names=new_names)
                groups.append(lifted_group)

    final_groups = []
    seen_params = set()

    for group in groups:
        valid_names = [name for name in group.param_names if name not in seen_params]
        if valid_names:
            if len(valid_names) != len(group.param_names):
                group = dataclasses.replace(group, param_names=valid_names)
            final_groups.append(group)
            seen_params.update(valid_names)

    for name, param in all_valid_params.items():
        if name not in seen_params:
            final_groups.append(ParameterGroup(param_names=[name], distribution=param.distribution))
            seen_params.add(name)

    return final_groups
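
The priority rule can be sketched with groups reduced to plain lists of names. The helper resolve_groups and all names are hypothetical, and distributions are left out; the sketch shows only how earlier groups claim parameters and how leftovers become singleton groups.

```python
def resolve_groups(groups, all_names):
    """Deduplicate groups in priority order, then add singleton groups."""
    final_groups, seen = [], set()
    for group in groups:
        # Drop names already claimed by a higher-priority group.
        remaining = [name for name in group if name not in seen]
        if remaining:
            final_groups.append(remaining)
            seen.update(remaining)
    for name in all_names:
        # Any parameter not in a group gets a group of its own.
        if name not in seen:
            final_groups.append([name])
            seen.add(name)
    return final_groups

parent_groups = [["sub.gain", "offset"]]          # defined on the parent module
submodule_groups = [["sub.gain", "sub.phase"]]    # lifted from a submodule
all_names = ["offset", "sub.gain", "sub.phase", "noise"]

# Parent groups come first, so they win any conflict over "sub.gain".
resolved = resolve_groups(parent_groups + submodule_groups, all_names)
```

The submodule group shrinks to the parameters the parent did not claim, and every parameter ends up in exactly one group.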

param_names(*args, **kwargs)

Return module parameter names as a list.

See parax.Module.named_params.

Source code in parax/module.py
def param_names(self, *args, **kwargs) -> list[str]:
    """
    Return module parameter names as a list.

    See [`parax.Module.named_params`][].
    """
    return list(self.named_params(*args, **kwargs).keys())

param_value(name, *args, **kwargs)

Return a single module parameter value by name as a single jax array.

See parax.Module.named_param_values.

Source code in parax/module.py
def param_value(self, name: str, *args, **kwargs) -> jnp.ndarray:
    """
    Return a single module parameter value by name as a single jax array.

    See [`parax.Module.named_param_values`][].
    """
    return self.named_param_values(*args, **kwargs)[name]

param_values(*args, **kwargs)

Return module parameter values as a list of jax arrays.

See parax.Module.named_param_values.

Source code in parax/module.py
def param_values(self, *args, **kwargs) -> list[jnp.ndarray]:
    """
    Return module parameter values as a list of jax arrays.

    See [`parax.Module.named_param_values`][].
    """
    return list(self.named_param_values(*args, **kwargs).values())

params(*args, **kwargs)

Return module parameters as a list.

See parax.Module.named_params.

Source code in parax/module.py
def params(self, *args, **kwargs) -> list[Parameter]:
    """
    Return module parameters as a list.

    See [`parax.Module.named_params`][].
    """
    return list(self.named_params(*args, **kwargs).values())

path_to_param_name(path)

Convert a PyTree path to a fully-qualified parameter name.

Source code in parax/module.py
def path_to_param_name(self, path) -> str:
    """Convert a PyTree path to a fully-qualified parameter name."""
    name_fields = []
    node = self

    for item in path:
        if isinstance(item, GetAttrKey):
            k = item.name
            next_node = getattr(node, k)

            # 1. Determine transparency
            is_transparent = getattr(node, '_transparent', False)
            if not is_transparent and is_dataclass(node):
                field_obj = next((f for f in fields(node) if f.name == k), None)
                if field_obj is not None:
                    is_transparent = field_obj.metadata.get('transparent', False)

            # 2. Extract user override
            explicit_name = getattr(next_node, 'name', None)

            # 3. Rule application
            if is_transparent:
                if explicit_name is not None:
                    name_fields.append(explicit_name)
            else:
                name_fields.append(explicit_name if explicit_name is not None else k)

            node = next_node

        elif isinstance(item, DictKey):
            k = item.key
            node = node[k]
            name_fields.append(str(k))

        elif isinstance(item, (SequenceKey, FlattenedIndexKey)):
            idx = item.idx if hasattr(item, 'idx') else item.key
            node = node[idx]
            explicit_name = getattr(node, 'name', None)
            if explicit_name is not None:
                name_fields.append(explicit_name)
            else:
                name_fields.append(str(idx))

        else:
            raise Exception(f"Unsupported key type in path: {type(item)}")

    return self._separator.join(name_fields)
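
The attribute-key naming rule can be illustrated without JAX. This is a minimal sketch, assuming each path step has already been reduced to a hypothetical `Step` tuple; `Step`, `join_name`, and the `"."` separator are illustrative assumptions, not part of parax.

```python
from typing import NamedTuple, Optional


class Step(NamedTuple):
    """One attribute step of a path (hypothetical stand-in for a GetAttrKey)."""
    field_name: str
    explicit_name: Optional[str] = None  # the child's `name` attribute, if any
    transparent: bool = False            # parent or field marked transparent


def join_name(steps, separator="."):
    """Transparent steps contribute only an explicit name; others fall back to the field name."""
    parts = []
    for step in steps:
        if step.transparent:
            if step.explicit_name is not None:
                parts.append(step.explicit_name)
        else:
            parts.append(step.field_name if step.explicit_name is None else step.explicit_name)
    return separator.join(parts)


path = [Step("amp", transparent=True), Step("gain", explicit_name="g"), Step("offset")]
print(join_name(path))
```

The transparent `amp` step is skipped entirely because it has no explicit name, while `gain` is replaced by its explicit name `g`.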

sampled(key=None, **kwargs)

Returns a new module with parameters sampled from this module's parameter distribution.

Returns:

Type Description
Module
Source code in parax/module.py
def sampled(self, key=None, **kwargs) -> 'Module':
    """Returns a new module with parameters sampled from this module's parameter distribution.

    Returns
    -------
    Module
    """
    dist = self.flat_distribution()
    flat_param_samples = dist.sample(key, sample_shape=(1,))[0]
    return self.with_params(flat_param_samples)

submodules()

Returns all nested submodules (depth-first), excluding self.

Returns:

Type Description
list[Module]
Source code in parax/module.py
def submodules(self) -> list['Module']:
    """Returns all nested submodules (depth-first), excluding ``self``.

    Returns
    -------
    list[Module]
    """
    return nodes_by_type(self, Module)[1:]         
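
The `[1:]` slice relies on the root being the first node visited. A minimal sketch of that idea, assuming a pre-order depth-first traversal; `Node` and `nodes_depth_first` are hypothetical stand-ins, and the real `nodes_by_type` helper may differ.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a module that holds child modules."""
    name: str
    children: list = field(default_factory=list)


def nodes_depth_first(root):
    """Return the root followed by all descendants in pre-order."""
    found = [root]
    for child in root.children:
        found.extend(nodes_depth_first(child))
    return found


def submodules(root):
    # The root always comes first, so slicing it off leaves only nested nodes.
    return nodes_depth_first(root)[1:]


tree = Node("model", [Node("a", [Node("a1")]), Node("b")])
names = [n.name for n in submodules(tree)]
print(names)
```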

with_all_params_fixed(**kwargs)

Returns a module with all parameters fixed.

This is an alias for calling parax.Module.with_free_params with fix_others=True and no parameters passed.

See parax.Module.with_free_params.

Source code in parax/module.py
def with_all_params_fixed(self: Self, **kwargs) -> Self:
    """Returns a module with all parameters fixed.

    This is an alias for calling [`parax.Module.with_free_params`][]
    with `fix_others=True` and no parameters passed.

    See [`parax.Module.with_free_params`][].
    """
    kwargs.setdefault('fix_others', True)
    if kwargs['fix_others'] == False:
        raise Exception("Cannot pass fix_others == False for `with_all_params_fixed`.")
    return self.with_free_params({}, **kwargs)
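
The alias methods in this group share one pattern: default a keyword with `setdefault`, then reject an explicit contradictory value. A minimal sketch under stated assumptions: the `with_free_params` below is a hypothetical stand-in that only reports its arguments, and the sketch raises `ValueError` where the source raises a plain `Exception`.

```python
def with_free_params(names, *, fix_others=False):
    """Hypothetical stand-in that just reports what it was asked to do."""
    return {"free": list(names), "fix_others": fix_others}


def with_all_params_fixed(**kwargs):
    # Fill in the default only if the caller did not supply the keyword.
    kwargs.setdefault("fix_others", True)
    # An explicit False contradicts the purpose of the alias, so reject it.
    if kwargs["fix_others"] is False:
        raise ValueError("fix_others=False is not allowed for this alias")
    return with_free_params([], **kwargs)


print(with_all_params_fixed())
```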

with_all_params_free(**kwargs)

Returns a module with all parameters free.

This is an alias for calling parax.Module.with_free_params with all parameters passed.

See parax.Module.with_free_params.

Source code in parax/module.py
def with_all_params_free(self: Self, **kwargs) -> Self:
    """Returns a module with all parameters free.

    This is an alias for calling [`parax.Module.with_free_params`][]
    with all parameters passed.

    See [`parax.Module.with_free_params`][].
    """
    kwargs.setdefault('include_fixed', True)
    if kwargs['include_fixed'] == False:
        raise Exception("Cannot pass include_fixed == False for `with_all_params_free`.")        
    return self.with_free_params(self.param_names(include_fixed=True), **kwargs)

with_attrs(*args, **kwargs)

Return a copy of the module with one or more attributes replaced.

This is similar to eqx.tree_at but uses string paths.

Usage

1. Single attribute update (path, value)

model.with_attrs('a.b.c', 10)

2. Batch nested updates via dictionary

model.with_attrs({'a.b.c': 10, 'x.y.z': 20})

3. Top-level attributes via keyword arguments

model.with_attrs(name="new_model", _transparent=True)

4. Combined dict and kwargs

model.with_attrs({'a.b.c': 10}, name="new_model")

Source code in parax/module.py
def with_attrs(self: Self, *args: Any, **kwargs: Any) -> Self:
    """
    Return a copy of the module with one or more attributes replaced.

    This is similar to `eqx.tree_at` but uses string paths.

    Usage
    -----
    # 1. Single attribute update (path, value)
    model.with_attrs('a.b.c', 10)

    # 2. Batch nested updates via dictionary
    model.with_attrs({'a.b.c': 10, 'x.y.z': 20})

    # 3. Top-level attributes via keyword arguments
    model.with_attrs(name="new_model", _transparent=True)

    # 4. Combined dict and kwargs
    model.with_attrs({'a.b.c': 10}, name="new_model")
    """
    all_updates = {}

    # Parse positional arguments
    if len(args) == 2 and isinstance(args[0], str):
        all_updates[args[0]] = args[1]
    elif len(args) == 1 and isinstance(args[0], dict):
        all_updates.update(args[0])
    elif len(args) > 0:
        raise ValueError(
            "Invalid positional arguments. Please provide either a single "
            "(path, value) pair or a dictionary of updates."
        )

    # Add any top-level kwargs
    all_updates.update(kwargs)

    # Fast exit if nothing to update
    if not all_updates:
        return self

    # Extract paths and their corresponding values in a consistent order
    paths = tuple(all_updates.keys())
    values = tuple(all_updates.values())

    # A single callable that returns a tuple of nodes, one per path
    def where_fn(tree):
        return tuple(operator.attrgetter(p)(tree) for p in paths)

    return eqx.tree_at(where_fn, self, values)
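
The argument parsing and the dotted-path lookup can be tried in isolation. This is a minimal sketch using only the standard library; `collect_updates` is a hypothetical helper that mirrors the parsing logic, and `SimpleNamespace` stands in for a module tree.

```python
import operator
from types import SimpleNamespace


def collect_updates(*args, **kwargs):
    """Merge a (path, value) pair or a dict of updates with keyword updates."""
    updates = {}
    if len(args) == 2 and isinstance(args[0], str):
        updates[args[0]] = args[1]
    elif len(args) == 1 and isinstance(args[0], dict):
        updates.update(args[0])
    elif args:
        raise ValueError("expected a (path, value) pair or a dict of updates")
    updates.update(kwargs)
    return updates


# operator.attrgetter accepts dotted names, so 'a.b.c' walks three attribute levels.
model = SimpleNamespace(a=SimpleNamespace(b=SimpleNamespace(c=1)), name="old")
updates = collect_updates({"a.b.c": 10}, name="new")
current = {path: operator.attrgetter(path)(model) for path in updates}
print(updates, current)
```

Because keyword arguments are merged last, a keyword overrides a dictionary entry that has the same key.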

with_defaults(*args, **kwargs) classmethod

Return this module type with default initialization arguments.

This is useful for reusing an existing module with preset default values, without having to create a new module type via inheritance.

Arguments are forwarded as if they were passed to __init__.

Returns:

Type Description
type[Module]
Source code in parax/module.py
@classmethod
def with_defaults(cls, *args, **kwargs) -> type[Self]:
    """Return this module type with default initialization arguments.

    This is useful for reusing an existing module with preset
    default values, without having to create a new
    module type via inheritance.

    Arguments are forwarded as if they were passed to `__init__`.

    Returns
    -------
    type[Module]
    """            
    class DefaultsWrapper:
        def __init__(self, p):
            self.p = p   

        def __call__(self, *call_args, **call_kwargs):
            baked_args = deepcopy(self.p.args)
            baked_kwargs = deepcopy(self.p.keywords)

            final_args = baked_args + call_args
            final_kwargs = {**baked_kwargs, **call_kwargs}

            return self.p.func(*final_args, **final_kwargs)

        def with_defaults(self, *new_args, **new_kwargs):
            merged_args = self.p.args + new_args
            merged_kwargs = {**self.p.keywords, **new_kwargs} if self.p.keywords else new_kwargs
            return DefaultsWrapper(partial(self.p.func, *merged_args, **merged_kwargs))

    return DefaultsWrapper(partial(cls, *args, **kwargs))    
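
The wrapper builds on `functools.partial`, whose merging behaviour can be seen on its own. A minimal sketch, assuming a hypothetical `Filter` dataclass in place of a real module type:

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class Filter:
    """Hypothetical module type used only for illustration."""
    cutoff: float
    order: int = 2


# Bake in a default; keyword arguments given at call time still override it.
LowPass = partial(Filter, cutoff=100.0)
default_filter = LowPass()
custom_filter = LowPass(order=4, cutoff=50.0)

# Chaining works the same way: a partial of a partial merges the keywords.
SteepLowPass = partial(LowPass, order=8)
steep_filter = SteepLowPass()
print(default_filter, custom_filter, steep_filter)
```

Unlike a bare `partial`, the `DefaultsWrapper` in the source deep-copies the baked arguments on every call, so mutable defaults are not shared between the instances it creates.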

with_demoted_param_groups()

Recursively demote parameter groups to the deepest possible submodule.

This method identifies parameter groups where every parameter belongs to the same immediate submodule. It moves those groups to the submodule, stripping the prefix. It then recursively calls this method on the submodules to ensure groups continue moving down the hierarchy as far as possible.

Returns:

Type Description
Self

A new module instance with parameter groups distributed to their lowest relevant submodules.

Source code in parax/module.py
def with_demoted_param_groups(self: Self) -> Self:
    """Recursively demote parameter groups to the deepest possible submodule.

    This method identifies parameter groups where every parameter belongs to the same 
    immediate submodule. It moves those groups to the submodule, stripping the prefix.
    It then recursively calls this method on the submodules to ensure groups continue 
    moving down the hierarchy as far as possible.

    Returns
    -------
    Self
        A new module instance with parameter groups distributed to their lowest 
        relevant submodules.
    """
    submodule_prefixes = {} 
    for f in dataclasses.fields(self):
        if isinstance(getattr(self, f.name), Module):
            prefix = f.name + self._separator
            submodule_prefixes[prefix] = f.name

    groups_to_keep = []
    submodule_groups = {name: [] for name in submodule_prefixes.values()}

    current_groups = self._param_groups if self._param_groups is not None else []

    for group in current_groups:
        demoted = False
        for prefix, field_name in submodule_prefixes.items():
            if all(name.startswith(prefix) for name in group.param_names):
                new_names = [name[len(prefix):] for name in group.param_names]
                new_group = dataclasses.replace(group, param_names=new_names)
                submodule_groups[field_name].append(new_group)
                demoted = True
                break

        if not demoted:
            groups_to_keep.append(group)

    new_fields = {}
    for prefix, field_name in submodule_prefixes.items():
        child_module: Module = getattr(self, field_name)

        groups_to_push = submodule_groups[field_name]
        if groups_to_push:
            child_module = child_module.with_param_groups(groups_to_push)

        child_module = child_module.with_demoted_param_groups()
        new_fields[field_name] = child_module

    new_module = self.with_fields(**new_fields)
    object.__setattr__(new_module, '_param_groups', groups_to_keep)
    return new_module
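
One level of the demotion step can be sketched on plain lists of names. This is an illustration under stated assumptions: groups are lists of parameter names rather than `ParameterGroup` objects, `demote` is a hypothetical helper, and `"."` is an assumed separator.

```python
def demote(groups, submodule_names, separator="."):
    """Split groups into those kept at this level and those pushed to a child."""
    kept = []
    pushed = {name: [] for name in submodule_names}
    for names in groups:
        for child in submodule_names:
            prefix = child + separator
            if all(n.startswith(prefix) for n in names):
                # Every member lives in this child: strip the prefix and move the group.
                pushed[child].append([n[len(prefix):] for n in names])
                break
        else:
            # No single child owns the whole group, so it stays at this level.
            kept.append(names)
    return kept, pushed


kept, pushed = demote(
    [["amp.gain", "amp.offset"], ["amp.gain2", "mixer.loss"]],
    ["amp", "mixer"],
)
print(kept, pushed)
```

The first group moves to `amp` with its prefix stripped; the second spans two submodules and therefore stays where it is. The method in the source then repeats this step inside each child.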

with_fields(*args, **kwargs)

Return a copy of this module with dataclass-style field replacements.

Parameters are forwarded to dataclasses.replace.

Source code in parax/module.py
def with_fields(self: Self, *args, **kwargs) -> Self:
    """
    Return a copy of this module with dataclass-style field replacements.

    Parameters are forwarded to `dataclasses.replace`.
    """
    new_module = dataclasses.replace(self, *args, **kwargs)

    for f in dataclasses.fields(self):
        if not f.init:
            val = getattr(self, f.name)
            object.__setattr__(new_module, f.name, deepcopy(val))

    return new_module
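
The extra loop over `init=False` fields matters because `dataclasses.replace` builds the copy through `__init__`, which resets those fields to their defaults. A minimal sketch, assuming a hypothetical frozen `Box` dataclass; `replace_keeping_state` is an illustrative name.

```python
import dataclasses
from copy import deepcopy
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Box:
    """Hypothetical frozen dataclass with one non-init field."""
    width: float
    tags: list = field(init=False, default_factory=list)


def replace_keeping_state(obj, **changes):
    new_obj = dataclasses.replace(obj, **changes)
    for f in dataclasses.fields(obj):
        if not f.init:
            # replace() re-runs __init__, which resets init=False fields,
            # so the old value is copied across explicitly.
            object.__setattr__(new_obj, f.name, deepcopy(getattr(obj, f.name)))
    return new_obj


box = Box(1.0)
object.__setattr__(box, "tags", ["calibrated"])
plain = dataclasses.replace(box, width=2.0)
kept = replace_keeping_state(box, width=2.0)
print(plain.tags, kept.tags)
```

The `deepcopy` keeps the new object from sharing mutable state with the original.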

with_fixed_params(param_filter, free_others=False, **kwargs)

Return a module with specified parameters fixed.

This maps each parameter in the filter, calling parax.Parameter.as_fixed on each.

See parax.Module.with_mapped_params.

Parameters:

Name Type Description Default
free_others bool

Also free all parameters not in the filter.

False

Returns:

Type Description
Self
Source code in parax/module.py
def with_fixed_params(self: Self, param_filter: str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool], free_others: bool = False, **kwargs) -> Self:
    """Return a module with specified parameters fixed.

    This maps each parameter in the filter, calling [`parax.Parameter.as_fixed`][] on each.

    See [`parax.Module.with_mapped_params`][].

    Parameters
    ----------
    free_others : bool, default=False
        Also free all parameters not in the filter.        

    Returns
    -------
    Self
    """
    map_others = None
    if free_others:
        map_others = lambda p: p.as_free()

    kwargs.setdefault('include_fixed', True) 

    return self.with_mapped_params(lambda p: p.as_fixed(), param_filter=param_filter, map_others=map_others, **kwargs)

with_fixed_submodules(submodules)

Fix all parameters in the given submodules.

Submodule parameters are obtained using parax.Module.param_names and subsequently fixed using parax.Module.with_fixed_params.

Parameters:

Name Type Description Default
submodules Module | Sequence[Module] | str | Sequence[str]

Submodules whose parameters should be fixed.

required

Returns:

Type Description
Self
Source code in parax/module.py
def with_fixed_submodules(self: Self, submodules: 'Module' | Sequence['Module'] | str | Sequence[str]) -> Self:
    """Fix all parameters in the given submodules.

    Submodule parameters are obtained using [`parax.Module.param_names`][]
    and subsequently fixed using [`parax.Module.with_fixed_params`][].

    Parameters
    ----------
    submodules : Module | Sequence[Module] | str | Sequence[str]
        Submodules whose parameters should be fixed.

    Returns
    -------
    Self
    """        
    module_param_names = self.param_names(include_fixed=True, submodules=submodules)
    return self.with_fixed_params(module_param_names)

with_free_params(param_filter, *, fix_others=False, **kwargs)

Free the specified parameters.

This maps each parameter in the filter, calling parax.Parameter.as_free on each.

See parax.Module.with_mapped_params.

Parameters:

Name Type Description Default
fix_others bool

Also fix all parameters not in the filter.

False

Returns:

Type Description
Self
Source code in parax/module.py
def with_free_params(self: Self, param_filter: str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool], *, fix_others: bool = False, **kwargs) -> Self:
    """Free the specified parameters.

    This maps each parameter in the filter, calling [`parax.Parameter.as_free`][] on each.

    See [`parax.Module.with_mapped_params`][].

    Parameters
    ----------
    fix_others : bool, default=False
        Also fix all parameters not in the filter.

    Returns
    -------
    Self
    """
    map_others = None
    if fix_others:
        map_others = lambda p: p.as_fixed()

    kwargs.setdefault('include_fixed', True) 

    return self.with_mapped_params(lambda p: p.as_free(), param_filter=param_filter, map_others=map_others, **kwargs)
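
The effect of `fix_others` can be sketched on a plain mapping of fixed flags. This is an assumption-laden illustration: `free_params` is a hypothetical helper, and a dictionary of booleans (True meaning fixed) stands in for a module's parameters.

```python
def free_params(fixed, names, fix_others=False):
    """Return a new name -> fixed mapping with `names` freed."""
    names = set(names)
    unknown = names - set(fixed)
    if unknown:
        raise ValueError(f"unknown parameters: {sorted(unknown)}")
    result = {}
    for name, is_fixed in fixed.items():
        if name in names:
            result[name] = False      # freed
        elif fix_others:
            result[name] = True       # everything else becomes fixed
        else:
            result[name] = is_fixed   # left untouched
    return result


state = {"gain": True, "offset": False, "loss": False}
only_gain = free_params(state, ["gain"], fix_others=True)
print(only_gain)
```

With `fix_others=True` this matches the behaviour described for `with_free_params_only`: the named parameters are free and all others are fixed.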

with_free_params_only(param_filter, **kwargs)

Returns a module with only the specified parameters freed.

This is an alias for calling parax.Module.with_free_params with fix_others=True.

See parax.Module.with_free_params.

Source code in parax/module.py
def with_free_params_only(self: Self, param_filter: str | list[str] | Callable[[str], bool], **kwargs) -> Self:
    """Returns a module with only the specified parameters freed.

    This is an alias for calling [`parax.Module.with_free_params`][]
    with `fix_others=True`.

    See [`parax.Module.with_free_params`][].
    """
    kwargs.setdefault('fix_others', True)
    if kwargs['fix_others'] == False:
        raise Exception("Cannot pass fix_others == False for `with_free_params_only`.")
    return self.with_free_params(param_filter, **kwargs)

with_free_submodules(submodules, fix_others=False, include_fixed=True)

Free all parameters in the given submodules.

Submodule parameters are obtained using parax.Module.param_names and subsequently freed using parax.Module.with_free_params.

Parameters:

Name Type Description Default
submodules Module | Sequence[Module] | str | Sequence[str]

Submodules whose parameters should be free.

required
include_fixed bool

Include fixed parameters in the submodule.

True
fix_others bool

Fix all other submodules.

False

Returns:

Type Description
Self
Source code in parax/module.py
def with_free_submodules(self: Self, submodules: 'Module' | Sequence['Module'] | str | Sequence[str], fix_others=False, include_fixed=True) -> Self:
    """Free all parameters in the given submodules.

    Submodule parameters are obtained using [`parax.Module.param_names`][]
    and subsequently freed using [`parax.Module.with_free_params`][].

    Parameters
    ----------
    submodules : Module | Sequence[Module] | str | Sequence[str]
        Submodules whose parameters should be free.
    include_fixed : bool, default=True
        Include fixed parameters in the submodule.
    fix_others : bool, default=False
        Fix all other submodules.

    Returns
    -------
    Self
    """        
    module_param_names = self.param_names(include_fixed=include_fixed, submodules=submodules)
    return self.with_free_params(module_param_names, fix_others=fix_others)

with_free_submodules_only(*args, include_fixed=False, **kwargs)

Returns a module with only the specified submodules freed.

This is an alias for calling parax.Module.with_free_submodules with fix_others=True and include_fixed=False by default.

See parax.Module.with_free_submodules.

Source code in parax/module.py
def with_free_submodules_only(self: Self, *args, include_fixed=False, **kwargs) -> Self:
    """Returns a module with only the specified submodules freed.

    This is an alias for calling [`parax.Module.with_free_submodules`][]
    with `fix_others=True` and `include_fixed=False` by default.

    See [`parax.Module.with_free_submodules`][].
    """     
    kwargs.setdefault('fix_others', True)
    if kwargs['fix_others'] == False:
        raise Exception("Cannot pass fix_others == False for `with_free_submodules_only`.")
    return self.with_free_submodules(*args, include_fixed=include_fixed, **kwargs)

with_mapped_distributions(mapper, dist_filter=None, *, map_others=None, param_groups=False)

Return a module with a function applied to its parameter distributions.

This method allows for bulk-updates of distributions, such as widening variances or changing distribution types.

If param_groups is False, the mapping is applied to the distributions of individual parameters (flattened).

If param_groups is True, the mapping is applied to the distributions of parax.ParameterGroup objects. This mode is recursive: it will traverse the module tree and apply the mapping to all explicit parameter groups in all submodules.

Parameters:

Name Type Description Default
mapper Callable[[AbstractDistribution], AbstractDistribution]

Function that takes a distribution and returns a new one.

required
dist_filter Callable[[AbstractDistribution], bool] | None

A predicate function. If provided, the mapping is only applied to distributions where dist_filter(dist) is True. If None, applies to all.

None
map_others Callable[[AbstractDistribution], AbstractDistribution] | None

An optional map to apply to all distributions NOT in the filter.

None
param_groups bool

If True, map distributions on parameter groups (recursively). If False, map distributions on individual parameters (flat).

False

Returns:

Type Description
Self

A new module with updated distributions.

Source code in parax/module.py
def with_mapped_distributions(
    self: Self, 
    mapper: Callable[[AbstractDistribution], AbstractDistribution], 
    dist_filter: Callable[[AbstractDistribution], bool] | None = None, 
    *, 
    map_others: Callable[[AbstractDistribution], AbstractDistribution] | None = None,
    param_groups: bool = False
) -> Self:
    """Return a module with a function applied to its parameter distributions.

    This method allows for bulk-updates of distributions, such as widening variances 
    or changing distribution types.

    If ``param_groups`` is False, the mapping is applied to the distributions 
    of individual parameters (flattened).

    If ``param_groups`` is True, the mapping is applied to the distributions 
    of [`parax.ParameterGroup`][] objects. This mode is recursive: it will traverse 
    the module tree and apply the mapping to all explicit parameter groups in all submodules.

    Parameters
    ----------
    mapper : Callable[[AbstractDistribution], AbstractDistribution]
        Function that takes a distribution and returns a new one.
    dist_filter : Callable[[AbstractDistribution], bool] | None, default=None
        A predicate function. If provided, the mapping is only applied to 
        distributions where ``dist_filter(dist)`` is True. If None, applies to all.
    map_others : Callable[[AbstractDistribution], AbstractDistribution] | None, default=None
        An optional map to apply to all distributions NOT in the filter.
    param_groups : bool, default=False
        If True, map distributions on parameter groups (recursively). 
        If False, map distributions on individual parameters (flat).

    Returns
    -------
    Self
        A new module with updated distributions.
    """
    mapped_module = self

    if param_groups:
        current_groups = self._param_groups if self._param_groups is not None else []
        for group in current_groups:
            if dist_filter is None or dist_filter(group.distribution):
                mapped_module = mapped_module.with_param_groups(group.with_distribution(mapper(group.distribution)))
            elif map_others is not None:
                mapped_module = mapped_module.with_param_groups(group.with_distribution(map_others(group.distribution)))

        new_submodules = {}
        for f in dataclasses.fields(mapped_module):
            child = getattr(mapped_module, f.name)
            if isinstance(child, Module):
                updated_child = child.with_mapped_distributions(
                    mapper, 
                    dist_filter, 
                    map_others=map_others, 
                    param_groups=True
                )
                new_submodules[f.name] = updated_child

        if new_submodules:
            mapped_module = mapped_module.with_fields(**new_submodules)

    else:
        def map_fn(node):
            if is_valid_param(node):
                if dist_filter is None or dist_filter(node.distribution):
                    return node.with_distribution(mapper(node.distribution))
                elif map_others is not None:
                    return node.with_distribution(map_others(node.distribution))
            return node

        mapped_module = jax.tree_util.tree_map(map_fn, self, is_leaf=is_valid_param)

    return mapped_module
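
The interplay of `mapper`, `dist_filter`, and `map_others` in the flat (non-group) mode can be sketched on a dictionary. The `Normal` and `Uniform` tuples below are hypothetical stand-ins for distribution objects, and `map_distributions` is an illustrative helper rather than parax API.

```python
from typing import NamedTuple


class Normal(NamedTuple):
    """Hypothetical stand-in for a distribution object."""
    loc: float
    scale: float


class Uniform(NamedTuple):
    """Second hypothetical distribution type, used to show filtering."""
    low: float
    high: float


def map_distributions(dists, mapper, dist_filter=None, map_others=None):
    """Apply `mapper` where the filter passes and `map_others` elsewhere."""
    out = {}
    for name, dist in dists.items():
        if dist_filter is None or dist_filter(dist):
            out[name] = mapper(dist)
        elif map_others is not None:
            out[name] = map_others(dist)
        else:
            out[name] = dist
    return out


priors = {"gain": Normal(0.0, 1.0), "loss": Uniform(0.0, 5.0)}
widened = map_distributions(
    priors,
    mapper=lambda d: d._replace(scale=d.scale * 2),
    dist_filter=lambda d: isinstance(d, Normal),
)
print(widened)
```

Only the `Normal` prior is widened; the `Uniform` prior is returned unchanged because no `map_others` was given.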

with_mapped_params(mapper, param_filter=None, *, map_others=None, prefixes=False, include_fixed=False, ignore_unknown=False)

Return a module with specified parameters mapped.

Parameters:

Name Type Description Default
mapper Callable[[Parameter], Parameter]

The map to apply to each parameter in the filter (or all if no filter).

required
param_filter str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] | None

Parameter names to map. If None, applies mapper to all parameters.

None
map_others Callable[[Parameter], Parameter] | None

An optional map to apply to all parameters NOT in the filter.

None
prefixes bool

Specifies that, when a string or list of strings is passed in param_filter, these must be interpreted as parameter prefixes to map and not full path names. Defaults to False.

False

Returns:

Type Description
Self
Source code in parax/module.py
def with_mapped_params(
    self: Self, 
    mapper: Callable[[Parameter], Parameter], 
    param_filter: str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] | None = None, 
    *, 
    map_others: Callable[[Parameter], Parameter] | None = None,
    prefixes: bool = False,
    include_fixed: bool = False,
    ignore_unknown: bool = False,
) -> Self:
    """Return a module with specified parameters mapped.

    Parameters
    ----------
    mapper : Callable[[Parameter], Parameter]
        The map to apply to each parameter in the filter (or all if no filter).
    param_filter : str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] | None, default=None
        Parameter names to map. If None, applies mapper to all parameters.
    map_others : Callable[[Parameter], Parameter] | None, default=None
        An optional map to apply to all parameters NOT in the filter.
    prefixes : bool, default=False
        Specifies that, when a string or list of strings is passed
        in `param_filter`, these must be interpreted as parameter prefixes
        to map and not full path names. Defaults to `False`.

    Returns
    -------
    Self
    """
    current_param_names = set(self.param_names(include_fixed=include_fixed))

    if param_filter is None:
        resolved_filter = current_param_names
    elif isinstance(param_filter, Callable):
        resolved_filter = {p for p in current_param_names if param_filter(p)}
    else:
        # Safely cast single items or dicts to lists
        if isinstance(param_filter, str):
            param_filter = [param_filter]
        elif isinstance(param_filter, Parameter):
            param_filter = [param_filter]
        elif isinstance(param_filter, dict):
            param_filter = list(param_filter.keys())
        else:
            param_filter = list(param_filter)

        # Safely check index 0 only if the list has elements
        if param_filter and isinstance(param_filter[0], str) and prefixes:
            for prefix in param_filter:
                if not any(name.startswith(prefix) for name in current_param_names):
                    if not ignore_unknown:
                        raise ValueError(f"Specified prefix '{prefix}' does not match any parameters in the module")
            valid_prefixes = tuple(param_filter)
            resolved_filter = {p for p in current_param_names if p.startswith(valid_prefixes)}

        elif param_filter and isinstance(param_filter[0], Parameter):
            param_ids = {id(p) for p in param_filter}
            resolved_filter = {name for name, p in self.named_params(include_fixed=include_fixed).items() if id(p) in param_ids}

        else:
            resolved_filter = set(param_filter)

        for param_name in resolved_filter:
            if param_name not in current_param_names:
                raise ValueError(f"Specified parameter '{param_name}' not found in module")

    # Directly map using JAX natively
    def map_fn(path, node):
        if is_valid_param(node):
            if not include_fixed and getattr(node, "fixed", False):
                return node
            name = self.path_to_param_name(path)
            if name in resolved_filter:
                return mapper(node)
            elif map_others is not None:
                return map_others(node)
        return node

    return jax.tree_util.tree_map_with_path(map_fn, self, is_leaf=is_valid_param)
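
The filter resolution and the mapping step can be sketched separately from the PyTree machinery. This simplified illustration works on a dictionary of values and omits fixed-parameter handling, `Parameter`-object filters, and `ignore_unknown`; `resolve_filter` and `map_params` are hypothetical helper names.

```python
def resolve_filter(names, param_filter=None, prefixes=False):
    """Return the subset of `names` selected by a filter (simplified sketch)."""
    names = set(names)
    if param_filter is None:
        return names
    if callable(param_filter):
        return {n for n in names if param_filter(n)}
    if isinstance(param_filter, str):
        param_filter = [param_filter]
    if prefixes:
        return {n for n in names if n.startswith(tuple(param_filter))}
    unknown = set(param_filter) - names
    if unknown:
        raise ValueError(f"unknown parameters: {sorted(unknown)}")
    return set(param_filter)


def map_params(params, mapper, selected, map_others=None):
    """Apply `mapper` to selected entries and `map_others` to the rest."""
    out = {}
    for name, value in params.items():
        if name in selected:
            out[name] = mapper(value)
        elif map_others is not None:
            out[name] = map_others(value)
        else:
            out[name] = value
    return out


params = {"amp.gain": 1.0, "amp.offset": 2.0, "mixer.loss": 3.0}
selected = resolve_filter(params, "amp.", prefixes=True)
result = map_params(params, lambda v: v * 10, selected, map_others=lambda v: -v)
print(selected, result)
```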

with_name(name)

Return a copy of this module with a different name.

Source code in parax/module.py
def with_name(self: Self, name: str | None) -> Self:
    """
    Return a copy of this module with a different name.
    """
    return self.with_fields(name=name)

with_no_param_groups()

Return a new module with all parameter groups removed recursively.

This clears the _param_groups of the current module and traverses all nested submodules (and sequences of submodules) to remove their parameter groups as well.

Returns:

Type Description
Self

A new module instance with no parameter groups.

Source code in parax/module.py
def with_no_param_groups(self: Self) -> Self:
    """Return a new module with all parameter groups removed recursively.

    This clears the `_param_groups` of the current module and traverses
    all nested submodules (and sequences of submodules) to remove their 
    parameter groups as well.

    Returns
    -------
    Self
        A new module instance with no parameter groups.
    """
    new_fields = {} 

    for f in dataclasses.fields(self):
        if f.name == '_param_groups':
            continue

        child = getattr(self, f.name)

        if isinstance(child, Module):
            new_fields[f.name] = child.with_no_param_groups()

        elif isinstance(child, (list, tuple)):
            if any(isinstance(x, Module) for x in child):
                new_fields[f.name] = type(child)(
                    x.with_no_param_groups() if isinstance(x, Module) else x 
                    for x in child
                )

    new_module = self.with_fields(**new_fields)
    object.__setattr__(new_module, '_param_groups', [])
    return new_module

with_param_groups(param_groups)

Return a module with parameter groups appended, replacing existing relationships.

This method implements an "atomic replacement" policy. If any parameter in an existing group is claimed by a new group, the entire existing group is removed.

This ensures that groups defining joint distributions are not left in an invalid broken state (e.g. having a dimension removed). Parameters that were in the removed group but not in the new group will revert to being ungrouped (handled by param_groups as singleton groups).

Parameters:

Name Type Description Default
param_groups ParameterGroup or list[ParameterGroup]

Group(s) to add.

required

Returns:

Type Description
Self
Source code in parax/module.py
def with_param_groups(self: Self, param_groups: ParameterGroup | list[ParameterGroup]) -> Self:
    """Return a module with parameter groups appended, replacing existing relationships.

    This method implements an "atomic replacement" policy. If *any* parameter in 
    an existing group is claimed by a new group, the *entire* existing group is 
    removed. 

    This ensures that groups defining joint distributions are not left in an 
    invalid broken state (e.g. having a dimension removed). Parameters that were 
    in the removed group but not in the new group will revert to being ungrouped 
    (handled by `param_groups` as singleton groups).

    Parameters
    ----------
    param_groups : ParameterGroup or list[ParameterGroup]
        Group(s) to add.

    Returns
    -------
    Self
    """       
    if not isinstance(param_groups, list):
        param_groups = [param_groups]

    new_claimed_params = set()
    for group in param_groups:
        new_claimed_params.update(group.param_names)

    current_groups = self._param_groups if self._param_groups is not None else []
    kept_existing_groups = []

    for group in current_groups:
        existing_group_params = set(group.param_names)
        if existing_group_params.isdisjoint(new_claimed_params):
            kept_existing_groups.append(group)

    new_list = kept_existing_groups + param_groups
    new_module = copy(self)
    object.__setattr__(new_module, '_param_groups', new_list)
    return new_module

with_params(params=None, check_missing=False, check_unknown=True, fix_others=False, include_fixed=False, **param_kwargs)

Return a new module with parameters updated.

This is a multi-purpose function that updates parameters differently depending on the type of the argument passed.

Parameters:

Name Type Description Default
params dict[str, Parameter] | dict[str, float] | ndarray | None

Parameter updates. If an array, all values must be provided (matching flat_params order). You may also pass keyword args.

None
check_missing bool

Require that all module parameters are specified.

False
check_unknown bool

Error if unknown parameter keys are provided.

True
fix_others bool

Fix any parameters not explicitly passed.

False
include_fixed bool

Include fixed parameters when interpreting params mapping.

False
**param_kwargs dict

Additional parameter updates by name.

{}

Returns:

Type Description
Self

Raises:

Type Description
Exception

If shape/order mismatches, unknown/missing names (when checked), or if arrays are found outside of Parameters.
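
The interaction between `check_unknown` and `check_missing` on the dictionary path can be sketched with plain sets. This is a simplified illustration, not the library implementation: `check_updates` is a hypothetical helper and parameters are reduced to names and floats.

```python
def check_updates(module_names, updates, check_missing=False, check_unknown=True):
    """Validate update keys against the module's parameter names."""
    known = set(module_names)
    unknown = set(updates) - known
    if check_unknown and unknown:
        raise Exception(f"unknown parameters: {sorted(unknown)}")
    missing = known - set(updates)
    if check_missing and missing:
        raise Exception(f"missing parameters: {sorted(missing)}")
    # Unknown keys are dropped silently when check_unknown is False.
    accepted = {k: v for k, v in updates.items() if k in known}
    return accepted, missing


accepted, missing = check_updates(
    ["gain", "offset"], {"gain": 2.0, "typo": 1.0}, check_unknown=False
)
print(accepted)  # {'gain': 2.0}
print(missing)   # {'offset'}
```

In the source listing, the names that end up in `missing` are the ones marked as fixed when `fix_others=True`.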

Source code in parax/module.py
def with_params(
    self: Self,
    params: dict[str, Parameter] | dict[str, float] | jnp.ndarray | None = None,
    check_missing: bool = False,
    check_unknown: bool = True,
    fix_others: bool = False,
    include_fixed: bool = False,
    **param_kwargs: Parameter | float,
) -> Self:
    """Return a new module with parameters updated.

    This is a multi-purpose function that updates parameters differently
    based on the types passed.

    Parameters
    ----------
    params : dict[str, Parameter] | dict[str, float] | jnp.ndarray | None, optional
        Parameter updates. If an array, **all** values must be provided
        (matching ``flat_params`` order). You may also pass keyword args.
    check_missing : bool, default=False
        Require that all module parameters are specified.
    check_unknown : bool, default=True
        Error if unknown parameter keys are provided.
    fix_others : bool, default=False
        Fix any parameters not explicitly passed.
    include_fixed : bool, default=False
        Include fixed parameters when interpreting ``params`` mapping.
    **param_kwargs : dict
        Additional parameter updates by name.

    Returns
    -------
    Self

    Raises
    ------
    Exception
        If shape/order mismatches, unknown/missing names (when checked),
        or if arrays are found outside of Parameters.
    """
    # 1. High-Efficiency Array Update Path
    if params is not None and not isinstance(params, dict):
        if len(param_kwargs) > 0:
            raise ValueError("Cannot pass both a flat array and explicit keyword arguments to with_params.")

        params_array = jnp.asarray(params)

        # Use `partition` to isolate the parameter values while preserving the exact tree structure
        dynamic, static = partition(self, include_fixed=include_fixed, param_objects=False)
        flat_dynamic, unflatten_fn = flatten_util.ravel_pytree(dynamic)

        if flat_dynamic.size != params_array.size:
            raise Exception(f"Array size mismatch: Expected {flat_dynamic.size} elements, but got {params_array.size}.")

        new_dynamic = unflatten_fn(params_array)
        return eqx.combine(new_dynamic, static)

    # 2. Dictionary / Kwargs Update Path
    params = params if params is not None else {}
    params = dict(params)
    params.update(param_kwargs)

    new_params = self.named_params(include_fixed=True)

    parent_keys = set(new_params.keys())
    input_keys = set(params.keys())
    potential_flat_keys = input_keys - parent_keys

    if potential_flat_keys:
        parents_to_scan = [p for p in parent_keys if p not in params]
        for parent_name in parents_to_scan:
            parent_param = new_params[parent_name]
            if parent_param.size > 0: 
                sub_params = parent_param.flattened(separator=self._separator)
                updates_found = False
                new_sub_values = []

                for i, sub_p in enumerate(sub_params):
                    suffix = sub_p.name if sub_p.name is not None else str(i)
                    flat_name = f"{parent_name}{self._separator}{suffix}"

                    if flat_name in params:
                        val = params[flat_name]
                        if hasattr(val, 'item') and getattr(val, "size", 1) == 1:
                            val = val.item()
                        try:
                            val = float(val)
                        except Exception:
                            raise Exception(f"Value for flat parameter '{flat_name}' must be convertible to float. Got: {val}")
                        new_sub_values.append(val)
                        del params[flat_name]
                        updates_found = True
                    else:
                        new_sub_values.append(sub_p.latent_value)

                if updates_found:
                    new_val_flat = jnp.array(new_sub_values)
                    new_val_shaped = new_val_flat.reshape(parent_param.latent_value.shape)

                    # Use .with_value() rather than dataclasses.replace
                    params[parent_name] = parent_param.with_value(new_val_shaped)            

    unknown_params = set(params.keys() - new_params.keys())
    if check_unknown and len(unknown_params) != 0:
        raise Exception(f"Error: the following parameters were passed but are not in the module: {unknown_params}")
    params = {k: v for k, v in params.items() if k not in unknown_params}

    if check_missing or fix_others:
        missing_params = set(new_params.keys() - params.keys())
        if check_missing and len(missing_params) != 0:
            raise Exception(f"Error: the following module parameters were missing: {missing_params}")
        if fix_others:
            for missing_param_name in missing_params:
                new_params[missing_param_name] = dataclasses.replace(new_params[missing_param_name], fixed=True)                    

    for name, value in params.items():
        if isinstance(value, Parameter):
            new_params[name] = value
        else:
            # Route primitive/array updates through .with_value()
            new_params[name] = new_params[name].with_value(jnp.asarray(value))

    # Fast tree mapping bypasses string iteration logic completely
    def map_fn(path, node):
        if is_valid_param(node):
            name = self.path_to_param_name(path)
            if name in new_params:
                new_param = new_params[name]
                if new_param.name is None:
                    new_param = dataclasses.replace(new_param, name=node.name)
                return new_param
        return node

    return jax.tree_util.tree_map_with_path(map_fn, self, is_leaf=is_valid_param)

with_submodule_fields(submodule, *args, **kwargs)

Return a copy of this module with dataclass-style field replacements on a nested sub-module.

Parameters are forwarded to dataclasses.replace.

Parameters:

Name Type Description Default
submodule str | Sequence[str]

The name of the submodule (or sequence of names) to traverse. Can be a single string with a path e.g. 'submodule1.submodule2', or a list of submodules e.g. ['submodule1', 'submodule2'].

required
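
The two accepted path forms and the recursive rebuild can be sketched with standard-library dataclasses. The classes `Inner`, `Middle`, and `Outer` and the helper `replace_nested` are hypothetical stand-ins for parax modules, used only to show the idea.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class Inner:
    gain: float


@dataclasses.dataclass(frozen=True)
class Middle:
    inner: Inner


@dataclasses.dataclass(frozen=True)
class Outer:
    middle: Middle


def replace_nested(obj, path, **changes):
    """Rebuild `obj` with `changes` applied to the node reached via `path`."""
    # Accept either a dotted string or a sequence of attribute names.
    keys = path.split(".") if isinstance(path, str) else list(path)
    if not keys:
        return dataclasses.replace(obj, **changes)
    head, rest = keys[0], keys[1:]
    # Rebuild the child first, then swap it into a copy of the parent.
    child = replace_nested(getattr(obj, head), rest, **changes)
    return dataclasses.replace(obj, **{head: child})


outer = Outer(middle=Middle(inner=Inner(gain=1.0)))
by_string = replace_nested(outer, "middle.inner", gain=2.0)
by_list = replace_nested(outer, ["middle", "inner"], gain=2.0)
print(by_string == by_list)  # True
```

The original `outer` is left unchanged; every object along the path is copied rather than mutated.
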
Source code in parax/module.py
def with_submodule_fields(self: Self, submodule: str | Sequence[str], *args, **kwargs) -> Self:
    """
    Return a copy of this module with dataclass-style field replacements on a nested sub-module.

    Parameters are forwarded to `dataclasses.replace`.

    Parameters
    ----------
    submodule : str | Sequence[str]
        The name of the submodule (or sequence of names) to traverse.
        Can be a single string with a path e.g. 'submodule1.submodule2',
        or a list of submodules e.g. ['submodule1', 'submodule2'].
    """
    if isinstance(submodule, str) and '.' in submodule:
        path = submodule.split('.')
    else:
        path = [submodule] if isinstance(submodule, str) else list(submodule)

    if not path:
        return self.with_fields(*args, **kwargs)

    target_key = path[0]

    if len(path) == 1:
        updated_child = getattr(self, target_key).with_fields(*args, **kwargs)
    else:
        child = getattr(self, target_key)
        updated_child = child.with_submodule_fields(path[1:], *args, **kwargs)

    return self.with_fields(**{target_key: updated_child})  

with_submodules(*args, **kwargs)

Return a copy of the module with one or more submodules replaced.

This method accepts paths formatted in the exact same way as parameter names (e.g. 'submodule1_submodule2_submodule3'), respecting transparency and custom names.

Usage

Single replacement

model.with_submodules('layer1_attention', new_attention_module)

Batch replacement

model.with_submodules({ 'layer1_attention': new_attn_1, 'layer2_attention': new_attn_2 })
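
Going by the source listing, the call forms are merged into a single update mapping before any replacement happens, and keyword arguments are accepted as well. A minimal sketch of that merging step, with a hypothetical helper name and strings standing in for modules:

```python
def collect_updates(*args, **kwargs):
    """Merge (path, value), dict, and keyword forms into one mapping."""
    updates = {}
    if len(args) == 2 and isinstance(args[0], str):
        updates[args[0]] = args[1]
    elif len(args) == 1 and isinstance(args[0], dict):
        updates.update(args[0])
    elif args:
        raise ValueError("expected a (path, value) pair or a dict of updates")
    # Keyword arguments are applied last, so they win on duplicate names.
    updates.update(kwargs)
    return updates


single = collect_updates("layer1_attention", "new_attn")
batch = collect_updates({"layer1_attention": "a1"}, layer2_attention="a2")
print(batch)  # {'layer1_attention': 'a1', 'layer2_attention': 'a2'}
```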

Source code in parax/module.py
def with_submodules(self: Self, *args: Any, **kwargs: Any) -> Self:
    """
    Return a copy of the module with one or more submodules replaced.

    This method accepts paths formatted in the exact same way as parameter names 
    (e.g. 'submodule1_submodule2_submodule3'), respecting transparency and custom names.

    Usage
    -----
    # Single replacement
    model.with_submodules('layer1_attention', new_attention_module)

    # Batch replacement
    model.with_submodules({
        'layer1_attention': new_attn_1,
        'layer2_attention': new_attn_2
    })
    """
    all_updates = {}

    # Parse positional arguments
    if len(args) == 2 and isinstance(args[0], str):
        all_updates[args[0]] = args[1]
    elif len(args) == 1 and isinstance(args[0], dict):
        all_updates.update(args[0])
    elif len(args) > 0:
        raise ValueError(
            "Invalid positional arguments. Please provide either a single "
            "(path, value) pair or a dictionary of updates."
        )

    # Add any top-level kwargs
    all_updates.update(kwargs)

    if not all_updates:
        return self

    # 1. Gather all submodules and map their string paths to absolute JAX paths
    name_to_jax_path = {}

    def traverse(node, current_path):
        # Flatten exactly one module-level deep to get paths to immediate submodules
        leaves, _ = jax.tree_util.tree_flatten_with_path(
            node, 
            is_leaf=lambda x: isinstance(x, Module) and x is not node
        )
        for sub_path, leaf in leaves:
            if isinstance(leaf, Module):
                full_path = current_path + sub_path
                # Map the absolute JAX path to your custom string format
                str_name = self.path_to_param_name(full_path)
                name_to_jax_path[str_name] = full_path
                # Recurse into the nested module
                traverse(leaf, full_path)

    traverse(self, ())

    # 2. Match requested updates to JAX paths
    paths_to_update = []
    values_to_update = []

    for str_name, new_module in all_updates.items():
        if str_name not in name_to_jax_path:
            raise ValueError(f"Submodule path '{str_name}' not found in module.")
        paths_to_update.append(name_to_jax_path[str_name])
        values_to_update.append(new_module)

    # 3. Create an extractor callable for eqx.tree_at using the JAX paths
    def where_fn(tree):
        extracted = []
        for jax_path in paths_to_update:
            node = tree
            for key in jax_path:
                if isinstance(key, GetAttrKey):
                    node = getattr(node, key.name)
                elif isinstance(key, DictKey):
                    node = node[key.key]
                elif isinstance(key, (SequenceKey, FlattenedIndexKey)):
                    # Compatibility for different JAX versions
                    idx = getattr(key, 'idx', getattr(key, 'key', None))
                    node = node[idx]
            extracted.append(node)
        return tuple(extracted)

    # Execute the atomic replacement
    return eqx.tree_at(where_fn, self, tuple(values_to_update))    

with_transformed_params(bijector, param_filter=None, **kwargs)

Return a module with a distreqx bijector applied to the specified parameters.

This utilizes the underlying transformed method on the matched Parameters, which updates their physical values, bounds, and distributions simultaneously while preserving the unconstrained latent values.
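
The idea of changing the physical value while leaving the latent value alone can be shown with a toy class. This is only a sketch under assumptions: `ToyParam` is hypothetical, it ignores bounds and distributions, and it assumes the new bijector is composed on top of any existing transform, which this page does not confirm.

```python
import math


class ToyParam:
    """Stores a latent value and a forward map to the physical value."""

    def __init__(self, latent, forward=lambda z: z):
        self.latent = latent
        self.forward = forward

    @property
    def value(self):
        # The physical value is always derived from the latent one.
        return self.forward(self.latent)

    def transformed(self, fn):
        # Compose the new map on top of the old one; the latent is untouched.
        inner = self.forward
        return ToyParam(self.latent, lambda z: fn(inner(z)))


p = ToyParam(0.0)
q = p.transformed(math.exp)
print(q.latent, q.value)  # 0.0 1.0
```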

Parameters:

Name Type Description Default
bijector AbstractBijector

The bijector to apply.

required
param_filter str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] | None

Parameter names to transform. If None, applies to all parameters.

None

Returns:

Type Description
Self
Source code in parax/module.py
def with_transformed_params(
    self: Self, 
    bijector: AbstractBijector, 
    param_filter: str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] | None = None, 
    **kwargs
) -> Self:
    """
    Return a module with a distreqx bijector applied to the specified parameters.

    This utilizes the underlying `transformed` method on the matched Parameters, 
    which updates their physical values, bounds, and distributions simultaneously 
    while preserving the unconstrained latent values.

    Parameters
    ----------
    bijector : distreqx.bijectors.AbstractBijector
        The bijector to apply.
    param_filter : str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] | None, default=None
        Parameter names to transform. If None, applies to all parameters.

    Returns
    -------
    Self
    """
    return self.with_mapped_params(
        mapper=lambda p: p.transformed(bijector), 
        param_filter=param_filter, 
        **kwargs
    )    

with_uniform_distributions(percentage, param_filter=None, *, respect_bounds=False, remove_param_groups=True, zero_values='keep', **kwargs)

Return a module with uniform distributions set centered on current parameter values.

The distributions are defined with bounds calculated as value * (1.0 +/- percentage).

Parameters:

Name Type Description Default
percentage float

The fractional half-width of the uniform distribution relative to the current value (e.g. 0.1 places the bounds 10% below and above the value).

required
param_filter str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool]

The parameters to be updated with new uniform distributions. For the default case, all are updated.

None
respect_bounds bool

Whether the min and max bounds of the current distributions should be respected. If True, the new bounds will not extend past these bounds.

False
remove_param_groups bool

Whether to remove parameter groups recursively when setting the uniform distributions. Otherwise, the joint distribution of the module may not be the desired uniform distribution.

True
zero_values str

How to treat zero values. Currently the only option is to keep them and their bounds as is.

'keep'

Returns:

Type Description
Self

A new module with updated parameter distributions.
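
The bound calculation, including the handling of negative and zero values, can be followed in a scalar sketch. `uniform_bounds` is a hypothetical helper that mirrors the element-wise logic of the source listing for a single float; `lower` and `upper` stand in for the existing bounds used when `respect_bounds=True`.

```python
def uniform_bounds(value, percentage, lower=float("-inf"), upper=float("inf")):
    """Return (min, max) for a uniform distribution around `value`."""
    if value == 0.0:
        # zero_values='keep': a zero value keeps degenerate bounds at zero.
        new_min, new_max = 0.0, 0.0
    elif value > 0.0:
        new_min = value * (1.0 - percentage)
        new_max = value * (1.0 + percentage)
    else:
        # For negative values the scaled ends swap roles.
        new_min = value * (1.0 + percentage)
        new_max = value * (1.0 - percentage)
    # Clip to the existing bounds (a no-op with the infinite defaults).
    return max(new_min, lower), min(new_max, upper)


print(uniform_bounds(4.0, 0.5))             # (2.0, 6.0)
print(uniform_bounds(-4.0, 0.5))            # (-6.0, -2.0)
print(uniform_bounds(4.0, 0.5, lower=3.0))  # (3.0, 6.0)
```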

Source code in parax/module.py
def with_uniform_distributions(self, percentage: float, param_filter: str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool] = None, *, respect_bounds=False, remove_param_groups=True, zero_values='keep', **kwargs) -> Self:
    """Return a module with uniform distributions set centered on current parameter values.

    The distributions are defined with bounds calculated as ``value * (1.0 +/- percentage)``.

    Parameters
    ----------
    percentage : float
        The fractional half-width of the uniform distribution relative to the current value
        (e.g. 0.1 places the bounds 10% below and above the value).
    param_filter : str | Sequence[str] | Parameter | Sequence[Parameter] | Callable[[str], bool], default=None
        The parameters to be updated with new uniform distributions. For the default case, all are updated.
    respect_bounds : bool, default=False
        Whether the `min` and `max` bounds of the current distributions should be respected.
        If `True`, the new bounds will not extend past these bounds.
    remove_param_groups : bool, default=True
        Whether to remove parameter groups recursively when setting the uniform distributions.
        Otherwise, the joint distribution of the module may not be the desired uniform distribution.
    zero_values : str, default='keep'
        How to treat zero values. Currently the only option is to keep them and their bounds as is.

    Returns
    -------
    Self
        A new module with updated parameter distributions.
    """        
    current_param_names = set(self.param_names(param_filter, **kwargs))

    def map_fn(path, param: Parameter):
        if is_valid_param(param):
            name = self.path_to_param_name(path)
            if name in current_param_names:
                value = jnp.asarray(param.value)

                # Calculate bounds element-wise natively in JAX
                base_min = jnp.where(value > 0.0, value * (1.0 - percentage), value * (1.0 + percentage))
                base_max = jnp.where(value > 0.0, value * (1.0 + percentage), value * (1.0 - percentage))

                if zero_values == 'keep':
                    new_min = jnp.where(value == 0.0, 0.0, base_min)
                    new_max = jnp.where(value == 0.0, 0.0, base_max)
                else:
                    raise Exception("Unknown option for 'zero_values'")

                if respect_bounds:
                    param_min = getattr(param, 'min', -jnp.inf) if param.bounds is None else param.bounds[0]
                    param_max = getattr(param, 'max', jnp.inf) if param.bounds is None else param.bounds[1]
                    new_min = jnp.maximum(new_min, param_min)
                    new_max = jnp.minimum(new_max, param_max)

                distribution = UniformDistribution(new_min, new_max)
                return param.with_distribution(distribution)
        return param

    new_module = jax.tree_util.tree_map_with_path(map_fn, self, is_leaf=is_valid_param)
    if remove_param_groups:
        new_module = new_module.with_no_param_groups()
    return new_module

parax.Operator

Bases: Module, Generic[OpInputs, OpOutputs]

A composable callable that applies some operation to input arguments.

Supports standard Python operator overloading to seamlessly compose operators into complex graphs.

parax.load(source)

Load a Parax PyTree (e.g., a Module, or a dict/list of Modules) from a file.

Parameters:

Name Type Description Default
source str, os.PathLike, or BinaryIO

The path to the saved file or an open file-like object containing the data.

required

Returns:

Type Description
Any

The deserialized PyTree (Module, dict, list, etc.).

Raises:

Type Description
TypeError

If the root object or any nested submodules fail to load and silently degrade into dictionaries (usually due to moved classes).
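
A degraded object is a plain dictionary that still carries a `py/object` key because its class could not be resolved. The sketch below shows the kind of recursive scan that detects one; `find_degraded` is a hypothetical helper that covers only dicts, lists, and tuples, and returns the location instead of raising.

```python
def find_degraded(obj, path="root"):
    """Return (path, class_name) of the first unresolved object, else None."""
    if isinstance(obj, dict):
        if "py/object" in obj:
            return path, obj["py/object"]
        for key, value in obj.items():
            found = find_degraded(value, f"{path}[{key!r}]")
            if found is not None:
                return found
    elif isinstance(obj, (list, tuple)):
        for index, value in enumerate(obj):
            found = find_degraded(value, f"{path}[{index}]")
            if found is not None:
                return found
    return None


decoded = {"models": [{"py/object": "old_pkg.MyModule", "gain": 1.0}]}
print(find_degraded(decoded))  # ("root['models'][0]", 'old_pkg.MyModule')
```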

Source code in parax/serialization.py
def load(source: str | os.PathLike | BinaryIO) -> Any:
    """
    Load a Parax PyTree (e.g., a Module, or a dict/list of Modules) from a file.

    Parameters
    ----------
    source : str, os.PathLike, or BinaryIO
        The path to the saved file or an open file-like object containing the data.

    Returns
    -------
    Any
        The deserialized PyTree (Module, dict, list, etc.).

    Raises
    ------
    TypeError
        If the root object or any nested submodules fail to load and 
        silently degrade into dictionaries (usually due to moved classes).
    """
    if isinstance(source, (str, os.PathLike)):
        with open(source, "r", encoding="utf8") as f:
            data = f.read()
    else:
        data = source.read()

    decoded = jsonpickle.decode(data)    

    # Recursively check for nested degraded objects across the entire PyTree
    def _verify_no_degraded_modules(obj, current_path="root"):
        if isinstance(obj, dict):
            # If a dict has 'py/object', jsonpickle failed to resolve the class path
            if 'py/object' in obj:
                failed_class = obj['py/object']
                raise TypeError(
                    f"Degraded object found at path '{current_path}'. "
                    f"Failed to instantiate the class '{failed_class}'. "
                    "Did you move or rename this class in your codebase?"
                )
            for k, v in obj.items():
                _verify_no_degraded_modules(v, f"{current_path}[{repr(k)}]")

        elif isinstance(obj, (list, tuple)):
            for i, v in enumerate(obj):
                _verify_no_degraded_modules(v, f"{current_path}[{i}]")

        elif isinstance(obj, eqx.Module):
            # Safely traverse Equinox/Parax modules and dataclasses
            for f in obj.__dataclass_fields__:
                _verify_no_degraded_modules(getattr(obj, f), f"{current_path}.{f}")

    _verify_no_degraded_modules(decoded)

    return decoded

parax.save(target, tree)

Save a Parax PyTree (e.g., a Module, or a dict/list of Modules) to a file.

Parameters:

Name Type Description Default
target str, os.PathLike, or BinaryIO

The path to the saved file or an open file-like object.

required
tree Any

The PyTree containing Parax modules to save.

required
Source code in parax/serialization.py
def save(target: str | os.PathLike | BinaryIO, tree: Any):
    """
    Save a Parax PyTree (e.g., a Module, or a dict/list of Modules) to a file.

    Parameters
    ----------
    target : str, os.PathLike, or BinaryIO
        The path to the saved file or an open file-like object.
    tree : Any
        The PyTree containing Parax modules to save.
    """
    # 1. Map over the PyTree, converting only Parax Modules into their saveable forms
    def to_saveable(node):
        if isinstance(node, Module):
            return node.saveable()
        return node

    tree_save = jtu.tree_map(
        to_saveable, 
        tree, 
        is_leaf=lambda x: isinstance(x, Module)
    )

    # 2. Encode the standardized tree
    data = jsonpickle.encode(tree_save, unpicklable=True)

    # 3. Write to file
    if isinstance(target, (str, os.PathLike)):
        with open(target, "w", encoding="utf8") as f:
            f.write(data)
    else:
        target.write(data)

parax.field

field(*, converter=None, static=False, save=True, transparent=False, **kwargs)

Custom field specifier for Parax modules.
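
The `save` and `transparent` flags are stored in the field's metadata mapping. The sketch below reproduces that routing on top of the standard library's `dataclasses.field` instead of `eqx.field`, so it leaves out `converter` and `static`; the `Example` class is hypothetical.

```python
import dataclasses


def field(*, save=True, transparent=False, **kwargs):
    """Route the two flags into the dataclass field metadata."""
    metadata = dict(kwargs.pop("metadata", {}))
    # Only non-default flags are recorded, mirroring the source listing.
    if not save:
        metadata["save"] = False
    if transparent:
        metadata["transparent"] = True
    return dataclasses.field(metadata=metadata, **kwargs)


@dataclasses.dataclass
class Example:
    cache: int = field(save=False, default=0)
    inner: int = field(transparent=True, default=1)
    plain: int = field(default=2)


meta = {f.name: dict(f.metadata) for f in dataclasses.fields(Example)}
print(meta)
# {'cache': {'save': False}, 'inner': {'transparent': True}, 'plain': {}}
```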

Source code in parax/field.py
def field(
    *,
    converter: Callable[[Any], Any] | None = None,
    static: bool = False,
    save: bool = True,
    transparent: bool = False,
    **kwargs: Any,
) -> Any:
    """Custom field specifier for Parax modules."""

    # Handle Parax-specific metadata
    metadata = dict(kwargs.pop("metadata", {}))
    if not save:
        metadata["save"] = False
    if transparent:
        metadata["transparent"] = True

    kwargs['metadata'] = metadata

    return eqx.field(
        converter=converter,
        static=static,
        **kwargs
    )

parax.parameters

Parameter factories with pre-defined probability distributions.

CenteredUniform(mean, half_width, *args, **kwargs)

Create a Parameter with a uniform distribution centered on mean, with bounds mean - half_width and mean + half_width.

Parameters:

Name Type Description Default
mean float | Sequence[float]

The mean value of the distribution. Can be a sequence for a multi-valued Parameter.

required
half_width float | Sequence[float]

The half-width value of the distribution. Can be a sequence for a multi-valued Parameter.

required
**kwargs

Additional keyword arguments passed to parax.Uniform.

{}

Returns:

Type Description
Parameter

The created Parameter object.
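
The bounds handed to the uniform distribution follow directly from the two arguments. A plain-Python sketch for the multi-valued case, with a hypothetical helper name and lists in place of arrays:

```python
def centered_bounds(means, half_widths):
    """Return element-wise (low, high) lists around each mean."""
    lows = [m - h for m, h in zip(means, half_widths)]
    highs = [m + h for m, h in zip(means, half_widths)]
    return lows, highs


lows, highs = centered_bounds([2.0, 10.0], [0.5, 1.0])
print(lows, highs)  # [1.5, 9.0] [2.5, 11.0]
```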

Source code in parax/parameters.py
def CenteredUniform(mean: float | Sequence[float], half_width: float | Sequence[float], *args, **kwargs) -> Parameter:
    r"""
    Create a `Parameter` with a uniform distribution centered on `mean`,
    with bounds `mean - half_width` and `mean + half_width`.

    Parameters
    ----------
    mean : float | Sequence[float]
        The mean value of the distribution. Can be a sequence for a multi-valued Parameter.
    half_width : float | Sequence[float]
        The half-width value of the distribution. Can be a sequence for a multi-valued Parameter.
    **kwargs
        Additional keyword arguments passed to [`parax.Uniform`][].

    Returns
    -------
    Parameter
        The created Parameter object.
    """
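    # Convert to arrays so that sequence inputs support element-wise arithmetic
    mean = jnp.asarray(mean, dtype=float)
    half_width = jnp.asarray(half_width, dtype=float)
    low = mean - half_width
    high = mean + half_width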

    return Uniform(low, high, *args, **kwargs)

Fixed(value, **kwargs)

Create a Parameter that is marked as fixed.

This sets the fixed flag of the parameter to True.

Parameters:

Name Type Description Default
value

The value of the parameter.

required
**kwargs

Additional keyword arguments passed to the Parameter constructor.

{}

Returns:

Type Description
Parameter

The created fixed Parameter object.

Source code in parax/parameters.py
def Fixed(value, **kwargs) -> Parameter:
    r"""
    Create a `Parameter` that is marked as fixed.

    This sets the `fixed` flag of the parameter to `True`.

    Parameters
    ----------
    value
        The value of the parameter.
    **kwargs
        Additional keyword arguments passed to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created fixed Parameter object.
    """
    value = jnp.array(value, dtype=float)
    return Parameter(value=value, fixed=True, **kwargs)

Free(value, **kwargs)

Create a Parameter that is marked as free (i.e., free to vary).

This sets the fixed flag of the parameter to False.

Parameters:

Name Type Description Default
value

The value of the parameter.

required
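n int, optional

The number of identical parameters to create in an array. This is not a named argument of Free; it is forwarded to the Parameter constructor through **kwargs. Defaults to None.

None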
**kwargs

Additional keyword arguments passed to the Parameter constructor.

{}

Returns:

Type Description
Parameter

The created free Parameter object.

Source code in parax/parameters.py
def Free(value, **kwargs) -> Parameter:
    r"""
    Create a `Parameter` that is marked as free (i.e., free to vary).

    This sets the `fixed` flag of the parameter to `False`.

    Parameters
    ----------
    value
        The value of the parameter.
    n : int, optional
        The number of identical parameters to create in an array. Forwarded to the
        `Parameter` constructor through `**kwargs`. Defaults to None.
    **kwargs
        Additional keyword arguments passed to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created free Parameter object.
    """
    value = jnp.array(value, dtype=float)
    return Parameter(value=value, fixed=False, **kwargs)

Normal(mean, std, value=None, **kwargs)

Create a Parameter with a normal (Gaussian) distribution.

Parameters:

Name Type Description Default
mean float | Sequence[float]

The mean of the distribution. Can be a sequence for a multi-valued Parameter.

required
std float | Sequence[float]

The standard deviation of the distribution. Can be a sequence for a multi-valued Parameter.

required
value optional

The initial value. If None, the mean of the distribution is used. Defaults to None.

None
**kwargs

Additional keyword arguments forwarded to the Parameter constructor.

{}

Returns:

Type Description
Parameter

The created Parameter object.

Source code in parax/parameters.py
def Normal(mean: float | Sequence[float], std: float | Sequence[float], value=None, **kwargs) -> Parameter:
    r"""
    Create a `Parameter` with a normal (Gaussian) distribution.

    Parameters
    ----------
    mean : float | Sequence[float]
        The mean of the distribution. Can be a sequence for a multi-valued Parameter.
    std : float | Sequence[float]
        The standard deviation of the distribution. Can be a sequence for a multi-valued Parameter.
    value : optional
        The initial value. If None, the mean of the distribution is used. Defaults to None.
    **kwargs
        Additional keyword arguments forwarded to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created Parameter object.
    """
    mean, std = jnp.array(mean, dtype=float), jnp.array(std, dtype=float)
    dists = dist.Normal(mean, std)
    values = mean if value is None else value
    return Parameter(value=values, distribution=dists, **kwargs)

RelativeNormal(mean, std_fraction, **kwargs)

Create a Parameter with a normal distribution defined by a relative standard deviation.

The scale (sigma) is calculated as: mean * std_fraction

Parameters:

Name Type Description Default
mean float | Sequence[float]

The center (mean) of the distribution.

required
std_fraction float | Sequence[float]

The standard deviation expressed as a fraction of the mean (also known as the coefficient of variation). e.g., 0.1 results in a distribution with sigma = 0.1 * mean.

required
**kwargs

Additional keyword arguments passed to parax.Normal.

{}

Returns:

Type Description
Parameter
Source code in parax/parameters.py
def RelativeNormal(mean: float | Sequence[float], std_fraction: float | Sequence[float], **kwargs) -> Parameter:
    r"""
    Create a `Parameter` with a normal distribution defined by a relative standard deviation.

    The scale (sigma) is calculated as: `abs(mean * std_fraction)`

    Parameters
    ----------
    mean : float | Sequence[float]
        The center (mean) of the distribution.
    std_fraction : float | Sequence[float]
        The standard deviation expressed as a fraction of the mean 
        (also known as the coefficient of variation).
        e.g., 0.1 results in a distribution with sigma = 0.1 * mean.
    **kwargs
        Additional keyword arguments passed to [`parax.Normal`][].

    Returns
    -------
    Parameter
    """
    mean_arr = jnp.array(mean, dtype=float)
    frac_arr = jnp.array(std_fraction, dtype=float)

    # Calculate absolute standard deviation
    sigma = jnp.abs(mean_arr * frac_arr)

    return Normal(mean=mean_arr, std=sigma, **kwargs)
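
The scale computation can be checked in isolation. The sketch below reproduces only the arithmetic from the source listing with `jax.numpy`; it does not construct a `Parameter`, and the input values are illustrative assumptions.

```python
import jax.numpy as jnp

# Illustrative inputs: a positive mean, a negative mean, and a zero mean.
mean = jnp.array([50.0, -2.0, 0.0], dtype=float)
std_fraction = jnp.array(0.1, dtype=float)  # a scalar broadcasts across all means

# Same expression as in the listing: the absolute value keeps sigma
# non-negative even when the mean is negative.
sigma = jnp.abs(mean * std_fraction)

print(sigma)  # approximately [5.0, 0.2, 0.0]
```

Note that a mean of zero produces a sigma of zero, so a relative standard deviation is only meaningful for parameters whose nominal value is non-zero.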

RelativeUniform(mean, deviation_fraction, *args, **kwargs)

Create a Parameter with a uniform distribution defined by a fractional deviation.

The bounds are calculated as: mean -/+ abs(mean * deviation_fraction)

Parameters:

Name Type Description Default
mean float | Sequence[float]

The center (mean) of the distribution.

required
deviation_fraction float | Sequence[float]

The relative radius of the distribution bounds as a fraction of the mean. e.g., 0.1 results in bounds of [0.9 * mean, 1.1 * mean].

required
*args

Additional positional arguments passed to parax.Uniform.

()
**kwargs

Additional keyword arguments passed to parax.Uniform.

{}

Returns:

Type Description
Parameter
Source code in parax/parameters.py
def RelativeUniform(mean: float | Sequence[float], deviation_fraction: float | Sequence[float], *args, **kwargs) -> Parameter:
    r"""
    Create a `Parameter` with a uniform distribution defined by a fractional deviation.

    The bounds are calculated as: `mean -/+ abs(mean * deviation_fraction)`

    Parameters
    ----------
    mean : float | Sequence[float]
        The center (mean) of the distribution.
    deviation_fraction : float | Sequence[float]
        The relative radius of the distribution bounds as a fraction of the mean.
        e.g., 0.1 results in bounds of [0.9 * mean, 1.1 * mean].
    **kwargs
        Additional keyword arguments passed to [`parax.Uniform`][].

    Returns
    -------
    Parameter
    """
    mean_arr = jnp.array(mean, dtype=float)
    frac_arr = jnp.array(deviation_fraction, dtype=float)

    # Calculate the absolute deviation (radius)
    delta = jnp.abs(mean_arr * frac_arr)

    return Uniform(mean_arr - delta, mean_arr + delta, *args, **kwargs)
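
A short check of the bound arithmetic, using only `jax.numpy` and illustrative values (no `Parameter` is created here). It shows why the listing takes the absolute value of the deviation: for a negative mean, the lower bound still ends up below the upper bound.

```python
import jax.numpy as jnp

# Illustrative inputs: one positive and one negative mean.
mean = jnp.array([10.0, -10.0], dtype=float)
deviation_fraction = jnp.array(0.1, dtype=float)

# Absolute radius, as in the listing above.
delta = jnp.abs(mean * deviation_fraction)

low = mean - delta
high = mean + delta

print(low)   # approximately [ 9., -11.]
print(high)  # approximately [11.,  -9.]
```

Without the absolute value, the negative mean would give a lower bound of -9 and an upper bound of -11, which is an inverted interval.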

Uniform(low, high, value=None, **kwargs)

Create a Parameter with a uniform distribution.

Parameters:

Name Type Description Default
low float | Sequence[float]

The lower value of the distribution. Can be a sequence for a multi-valued Parameter.

required
high float | Sequence[float]

The upper value of the distribution. Can be a sequence for a multi-valued Parameter.

required
value optional

The initial value. If None, the midpoint of the distribution is used. Defaults to None.

None
**kwargs

Additional keyword arguments passed to the Parameter constructor.

{}

Returns:

Type Description
Parameter

The created Parameter object.

Source code in parax/parameters.py
def Uniform(low: float | Sequence[float], high: float | Sequence[float], value=None, **kwargs) -> Parameter:
    r"""
    Create a `Parameter` with a uniform distribution.

    Parameters
    ----------
    low : float | Sequence[float]
        The lower value of the distribution. Can be a sequence for a multi-valued Parameter.
    high : float | Sequence[float]
        The upper value of the distribution. Can be a sequence for a multi-valued Parameter.
    value : optional
        The initial value. If None, the midpoint of the distribution is used. Defaults to None.
    **kwargs
        Additional keyword arguments passed to the `Parameter` constructor.

    Returns
    -------
    Parameter
        The created Parameter object.
    """
    low, high = jnp.array(low, dtype=float), jnp.array(high, dtype=float)
    dists = dist.Uniform(low, high)
    values = (low + high) / 2.0 if value is None else value
    return Parameter(value=values, distribution=dists, **kwargs)
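
The default-value rule can be sketched on its own. `initial_value` below is a hypothetical helper written for illustration, not part of the library; it mirrors the `values = ...` line in the listing and leaves out the distribution and the `Parameter` construction.

```python
import jax.numpy as jnp


def initial_value(low, high, value=None):
    """Hypothetical helper: midpoint of the bounds unless a value is given."""
    low, high = jnp.array(low, dtype=float), jnp.array(high, dtype=float)
    return (low + high) / 2.0 if value is None else value


# Sequence bounds give one midpoint per element.
default = initial_value([0.0, 1.0], [2.0, 5.0])

# An explicit value is returned unchanged.
explicit = initial_value(0.0, 2.0, value=0.5)

print(default)   # [1. 3.]
print(explicit)  # 0.5
```

The shown code path does not check that an explicit value lies between the bounds; whether the `Parameter` constructor validates this is not covered in this section.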

parax.op

Binary

Bases: Operator[OpInputs, OpOutputs]

Returns the result of a callable that accepts the results of two operators.

The callable fn must have the signature fn(left, right).

Constant

Bases: Operator[OpInputs, OpOutputs]

Returns a fixed constant array or scalar.

Derivative

Bases: Operator[OpInputs, OpOutputs]

Computes numerical derivative with respect to a context attribute.

Diagonal

Bases: Operator[OpInputs, OpOutputs]

Extracts the diagonals of matrices.

Flatness

Bases: Derivative

Enforces gain flatness by computing the first derivative.
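
To illustrate why a first derivative measures flatness, here is a minimal sketch using a forward finite difference. The differencing scheme actually used by `Derivative` is not shown in this section, so `first_derivative` and the sample data are assumptions made for illustration.

```python
import jax.numpy as jnp

# Hypothetical gain samples over a frequency grid (illustrative values only).
freq = jnp.linspace(1.0, 2.0, 5)
flat_gain = jnp.full_like(freq, 10.0)
sloped_gain = 10.0 + 3.0 * (freq - 1.0)


def first_derivative(y, x):
    """Forward finite difference dy/dx between neighbouring samples."""
    return jnp.diff(y) / jnp.diff(x)


flat_slope = first_derivative(flat_gain, freq)      # zero everywhere
sloped_slope = first_derivative(sloped_gain, freq)  # about 3.0 everywhere

print(flat_slope)
print(sloped_slope)
```

A perfectly flat response has a derivative of zero at every sample, so driving the derivative toward zero pushes the response toward flatness.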

Index

Bases: Operator[OpInputs, OpOutputs]

Slices or indexes the output of another operator.

Lambda

Bases: Operator[OpInputs, OpOutputs]

Wraps a standard Python or JAX callable with the same domain as the operator.

Map

Bases: Operator[OpInputs, OpOutputs]

Applies an arbitrary function to a single operator's output.

Mask

Bases: Operator[OpInputs, OpOutputs]

Applies a boolean mask to the output of an operator.

Method

Bases: Operator[OpInputs, OpOutputs]

Dynamically accesses and executes a method on the first argument.

OffDiagonal(operator, n_ports, **kwargs)

Bases: Mask

Extracts off-diagonal elements.

Source code in parax/op/math.py
def __init__(self, operator: Operator, n_ports: int, **kwargs):
    mask = ~jnp.eye(n_ports, dtype=bool)
    # We initialize the parent Mask class with the generated eye mask
    super().__init__(operator=operator, mask=mask, **kwargs)
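
The mask built in the constructor can be inspected directly. The sketch below uses only `jax.numpy`; how the parent `Mask` class applies the mask is not shown in this section, so both the boolean indexing and the `jnp.where` variant are assumptions for illustration.

```python
import jax.numpy as jnp

n_ports = 3

# Same expression as in the constructor: True off the diagonal, False on it.
mask = ~jnp.eye(n_ports, dtype=bool)

# Illustrative 3x3 matrix with entries 0..8.
matrix = jnp.arange(9.0).reshape(n_ports, n_ports)

# Option 1 (assumed): boolean indexing returns the selected entries as a flat array.
off_diag = matrix[mask]

# Option 2 (assumed): keep the shape and zero out the diagonal instead.
zeroed = jnp.where(mask, matrix, 0.0)

print(off_diag)  # [1. 2. 3. 5. 6. 7.]
print(zeroed)
```

Boolean indexing produces an output whose size depends on the mask, so it needs a concrete (non-traced) mask; the `jnp.where` form keeps a fixed shape.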

Reduce

Bases: Operator[OpInputs, OpOutputs]

Applies a reduction (e.g., jnp.max, jnp.mean) over a specific axis.
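
For reference, this is what the named reductions do when applied over an axis with plain `jax.numpy`. The array is illustrative, and the snippet calls the reduction functions directly rather than going through the `Reduce` operator, whose constructor arguments are not listed in this section.

```python
import jax.numpy as jnp

x = jnp.array([[1.0, 5.0, 3.0],
               [4.0, 2.0, 6.0]])

row_max = jnp.max(x, axis=1)    # maximum of each row
col_mean = jnp.mean(x, axis=0)  # mean of each column
total = jnp.sum(x)              # sum over all elements

print(row_max)   # [5. 6.]
print(col_mean)  # [2.5 3.5 4.5]
print(total)     # 21.0
```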

Stack

Bases: Operator[OpInputs, OpOutputs]

Stacks the results of multiple operators along an axis.

Sum

Bases: Reduce

Convenience class for summing an operator's output.

Where

Bases: Operator[OpInputs, OpOutputs]

A conditional branching node using jax.lax.cond.

Evaluates a boolean condition (from an Operator) and returns the output of either true_branch or false_branch depending on the condition.
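
The underlying primitive can be tried on its own. The sketch below calls `jax.lax.cond` directly with two illustrative branch functions; `select` and the branch bodies are hypothetical, and the snippet does not show how `Where` wires operators into the branches.

```python
import jax
import jax.numpy as jnp


def select(condition, x):
    """Hypothetical example: pick one of two computations based on a scalar flag."""
    # Signature: jax.lax.cond(pred, true_fun, false_fun, *operands)
    return jax.lax.cond(
        condition,
        lambda v: v * 2.0,  # used when the condition is True
        lambda v: v - 1.0,  # used when the condition is False
        x,
    )


x = jnp.array([1.0, 2.0, 3.0])

doubled = select(jnp.array(True), x)
shifted = select(jnp.array(False), x)

# The same function also works under jit, where the condition is traced.
jitted = jax.jit(select)(jnp.array(True), x)

print(doubled)  # [2. 4. 6.]
print(shifted)  # [0. 1. 2.]
```

`jax.lax.cond` requires a scalar predicate, and both branches must return outputs with the same structure, shape, and dtype. An elementwise selection over an array of conditions would need a different tool such as `jnp.where`.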