field (pmrf.field)

pmrf.field(*, converter: Callable[[Any], Any] | None = None, static: bool = False, **kwargs: Any) -> Any

Equinox supports extra functionality on top of the default dataclasses.

Arguments:

  • converter: a function to call on this field when the model is initialised. For example, `field(converter=jax.numpy.asarray)` converts bool/int/float/complex values to JAX arrays. The converter runs after the `__init__` method (when using a user-provided `__init__`), and after `__post_init__` (when using the default dataclass initialisation). If `converter` is `None`, no converter is registered.

  • static: whether the field should not interact with any JAX transform at all (by making it part of the PyTree structure rather than a leaf).

  • **kwargs: all other keyword arguments are passed on to `dataclasses.field`.

!!! example "Example for `converter`"

    ```python
    import equinox as eqx
    import jax

    class MyModule(eqx.Module):
        foo: jax.Array = eqx.field(converter=jax.numpy.asarray)

    mymodule = MyModule(1.0)
    assert isinstance(mymodule.foo, jax.Array)
    ```

!!! example "Example for `static`"

    ```python
    import equinox as eqx
    import jax

    class MyModule(eqx.Module):
        normal_field: str
        static_field: str = eqx.field(static=True)

    mymodule = MyModule("normal", "static")
    leaves, treedef = jax.tree_util.tree_flatten(mymodule)
    assert leaves == ["normal"]
    assert "static" in str(treedef)
    ```

`static=True` means that this field is not a node of the PyTree, so it does not interact with any JAX transform, such as `jax.jit` or `jax.grad`. As a consequence, storing a JAX array in a static field is usually a bug. Use `static=True` very rarely: when you need to select only some fields, it is usually better to filter them out with `eqx.partition`.