Partial (pmrf.Partial)

class pmrf.Partial(func: Callable[[...], _Return], /, *args: Any, **kwargs: Any)

Bases: Module, Generic[_Return]

Like functools.partial, but treats the wrapped function and the partially applied args and kwargs as a PyTree.

This is very much like jax.tree_util.Partial. The difference is that the JAX version requires func to be specifically a function, and will silently misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast, this version (like equinox.Partial) allows arbitrary callables.
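Here is a minimal pure-Python sketch of why it matters that the wrapped function is part of the tree. It does not use pmrf, Equinox or JAX. `Scale`, `SketchPartial` and the toy `tree_map` are hypothetical stand-ins, and real PyTree machinery handles many more node types. Because the callable is a child node rather than static structure, a tree-wide transformation reaches the parameter stored inside it as well as the partially applied keyword argument.

```python
from typing import Any, Callable


class Scale:
    """A non-function callable with a parameter (a stand-in for something like an MLP)."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float, bias: float = 0.0) -> float:
        return self.factor * x + bias


class SketchPartial:
    """Hypothetical partial whose func, args and keywords are all tree children."""

    def __init__(self, func: Callable[..., Any], /, *args: Any, **kwargs: Any) -> None:
        self.func = func
        self.args = args
        self.keywords = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Stored positional args first, then call-time ones; keywords merged
        # the way functools.partial does it.
        return self.func(*self.args, *args, **{**self.keywords, **kwargs})


def tree_map(fn: Callable[[Any], Any], node: Any) -> Any:
    """Toy recursive tree_map: rebuilds known node types and applies fn to leaves."""
    if isinstance(node, SketchPartial):
        # The wrapped callable is a child too, so fn reaches the parameters inside it.
        return SketchPartial(
            tree_map(fn, node.func),
            *tree_map(fn, node.args),
            **tree_map(fn, node.keywords),
        )
    if isinstance(node, Scale):
        return Scale(tree_map(fn, node.factor))
    if isinstance(node, tuple):
        return tuple(tree_map(fn, child) for child in node)
    if isinstance(node, dict):
        return {key: tree_map(fn, value) for key, value in node.items()}
    return fn(node)  # anything else is treated as a leaf


p = SketchPartial(Scale(3.0), bias=1.0)
doubled = tree_map(lambda leaf: leaf * 2, p)
print(p(2.0))        # 3.0 * 2.0 + 1.0 = 7.0
print(doubled(2.0))  # 6.0 * 2.0 + 2.0 = 14.0
```

If the toy tree_map instead treated func as opaque static data, only bias would be doubled and factor would stay at 3.0.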

Arguments:

  • func: the callable to partially apply.

  • *args: any positional arguments to apply.

  • **kwargs: any keyword arguments to apply.
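The `/` in the signature makes func positional-only, so a keyword argument that happens to be named func is stored for the wrapped callable instead of clashing with the constructor's own parameter. The snippet below shows this with functools.partial, which has the same calling convention and the same func / args / keywords attributes. That pmrf.Partial stores its arguments identically is an assumption based on its signature and attribute list, and `describe` is a hypothetical example function.

```python
import functools


def describe(*, func: str, times: int) -> str:
    # A callable whose own keyword parameter happens to be called `func`.
    return f"{func} x{times}"


# `func` is positional-only in functools.partial(func, /, *args, **keywords),
# so `func="relu"` below is stored as a keyword argument for `describe`.
p = functools.partial(describe, func="relu")
print(p.func is describe)  # True: the wrapped callable
print(p.args)              # (): no positional arguments were applied
print(p.keywords)          # {'func': 'relu'}
print(p(times=3))          # relu x3
```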

__call__(*args: Any, **kwargs: Any) → _Return

Call the wrapped self.func.

Arguments:

  • *args: the positional arguments to apply. These are passed after the positional arguments given to __init__.

  • **kwargs: any keyword arguments to apply.

Returns:

The result of the wrapped function.
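The argument order described above (positional arguments stored at construction first, then those given at call time) is the same as for functools.partial, which this class is described as being like. The sketch below uses functools.partial directly to show the ordering. It does not exercise the PyTree behaviour of pmrf.Partial, and `affine` is a hypothetical example function.

```python
import functools


def affine(weight: float, x: float, bias: float = 0.0) -> float:
    return weight * x + bias


# `weight` is fixed at construction time; call-time positional args come after it.
p = functools.partial(affine, 2.0)
print(p(5.0))            # affine(2.0, 5.0) -> 10.0
print(p(5.0, bias=1.0))  # affine(2.0, 5.0, bias=1.0) -> 11.0
```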

args: tuple[Any, ...]
func: Callable[[...], _Return]
keywords: dict[str, Any]