Partial (pmrf.Partial)
- class pmrf.Partial(func: Callable[..., _Return], /, *args: Any, **kwargs: Any)
Bases: Module, Generic[_Return]
Like functools.partial, but treats the wrapped function, and the partially applied args and kwargs, as a PyTree.
This is very much like jax.tree_util.Partial. The difference is that the JAX version requires func to be specifically a function, and will silently misbehave if given any non-function callable, e.g. equinox.nn.MLP. In contrast, this version allows arbitrary callables.
Arguments:
func: the callable to partially apply.
*args: any positional arguments to apply.
**kwargs: any keyword arguments to apply.
- __call__(*args: Any, **kwargs: Any) -> _Return
Call the wrapped self.func.
Arguments:
*args: any positional arguments to apply. These are passed after the positional arguments given to __init__.
**kwargs: any keyword arguments to apply.
Returns:
The result of the wrapped function.
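- args: tuple[Any, ...]
The positional arguments applied at construction.
- func: Callable[..., _Return]
The wrapped callable.
- keywords: dict[str, Any]
The keyword arguments applied at construction.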