Partial (pmrf.Partial)

class pmrf.Partial(func: Callable[..., _Return], /, *args: Any, **kwargs: Any)

Bases: Module, Generic[_Return]

(experimental) Like functools.partial, but JAX-compatible.

Nested partials are flattened, and when the same keyword is bound more than once the most recently supplied value takes priority, mirroring the behavior of functools.partial.

Parameters:
  • func (Callable[..., _Return]) – The callable to partially apply.

  • *args (Any) – Positional arguments to bind.

  • **kwargs (Any) – Keyword arguments to bind.

Variables:
  • func (Callable[..., _Return]) – The unwrapped underlying callable.

  • args (tuple[Any, ...]) – The bound positional arguments.

  • keywords (dict[str, Any]) – The bound keyword arguments.
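
The flattening and keyword priority described above can be illustrated with functools.partial itself, which this class is documented to mirror. This is only a sketch: it does not import pmrf, the function scale_and_shift is a hypothetical example, and it assumes pmrf.Partial exposes func, args, and keywords in the same way.

```python
import functools


def scale_and_shift(x, scale=1.0, shift=0.0):
    """Hypothetical example function used only for illustration."""
    return x * scale + shift


# Wrap a partial in another partial, rebinding `scale` along the way.
inner = functools.partial(scale_and_shift, scale=2.0)
outer = functools.partial(inner, scale=3.0, shift=1.0)

# The nesting is flattened: `func` is the original callable, not `inner`.
print(outer.func is scale_and_shift)

# The newer binding of `scale` replaces the older one.
print(outer.keywords)

# Calls use the merged bindings: 2.0 * 3.0 + 1.0
print(outer(2.0))
```

Because the outer partial stores the original callable directly, the intermediate wrapper adds no extra layer at call time.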

__call__(*args: Any, **kwargs: Any) → _Return

Invoke the wrapped callable with bound and newly supplied arguments.

Parameters:
  • *args (Any) – Additional positional arguments, appended after the bound positional arguments.

  • **kwargs (Any) – Additional keyword arguments. These override any bound keywords.

Returns:

The result of the wrapped callable.

Return type:

_Return
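
A minimal sketch of these call semantics, again using functools.partial as a stand-in on the assumption that pmrf.Partial behaves the same way at call time. The function describe and its parameters are hypothetical.

```python
import functools


def describe(a, b, c, *, unit="m", precision=2):
    """Hypothetical example function; returns its arguments as a tuple."""
    return (a, b, c, unit, precision)


# Bind two positional arguments and one keyword argument.
bound = functools.partial(describe, 1, 2, unit="m")

# The call-time positional argument is appended after the bound ones,
# and the call-time keyword overrides the bound `unit`.
result = bound(3, unit="km")
print(result)

# The bound keyword is unchanged on the partial itself.
print(bound.keywords)
```

The override applies only to that call; the bindings stored on the partial are not modified.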

args: tuple[Any, ...]
func: Callable[..., _Return]
keywords: dict[str, Any]