espnet3.parallel.parallel.wrap_func_with_worker_env
espnet3.parallel.parallel.wrap_func_with_worker_env(func: Callable) → Callable
Wrap a user-defined function so that it can transparently consume per-worker environment variables registered via a WorkerPlugin.
This wrapper inspects the function's signature and automatically supplies keyword arguments from the worker's environment (worker.plugins["env"]) when they match parameter names of the function and are not explicitly provided by the caller.
Conflict detection: If both the worker environment and the call's kwargs provide the same argument name, the wrapper raises a ValueError before calling the underlying function.
- Parameters: func (Callable) - The original user-defined function to be executed on the worker. It may have positional parameters, keyword parameters, and/or a **kwargs catch-all.
- Returns: A wrapped function that:
  1. Runs on the worker.
  2. Retrieves the environment dict from worker.plugins["env"].
  3. Detects and errors on conflicts with explicit keyword arguments.
  4. Supplies any missing keyword arguments from the environment.
- Return type: Callable
- Raises: ValueError - If at least one parameter name is present both in the worker environment and in the keyword arguments provided to the call.
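The four steps listed under Returns can be approximated in plain Python. The sketch below is hypothetical and is not ESPnet's implementation: `wrap_with_env` and its `get_env` argument are invented names, and `get_env()` stands in for reading `worker.plugins["env"]` so that the code runs without a Dask cluster (step 1, running on the worker, is therefore left out). It also assumes that conflicts are checked against every environment key and that parameters already filled positionally are not injected a second time.

```python
import functools
import inspect
from typing import Any, Callable, Dict

# Parameter kinds that can legally be passed by keyword.
_KEYWORD_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


def wrap_with_env(func: Callable, get_env: Callable[[], Dict[str, Any]]) -> Callable:
    """Hypothetical sketch of the env-injecting wrapper described above."""
    sig = inspect.signature(func)
    keyword_names = {p.name for p in sig.parameters.values() if p.kind in _KEYWORD_KINDS}
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
    )

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Step 2: fetch the per-worker environment (worker.plugins["env"] on Dask).
        env = get_env()

        # Step 3: refuse to guess when env and explicit kwargs name the same argument.
        conflicts = sorted(set(env) & set(kwargs))
        if conflicts:
            raise ValueError(f"Arguments given by both worker env and kwargs: {conflicts}")

        # Parameters already filled positionally count as explicitly provided.
        provided = set(sig.bind_partial(*args).arguments)

        # Step 4: fill in missing keyword arguments from the environment.
        for key, value in env.items():
            if key not in provided and (key in keyword_names or accepts_var_kw):
                kwargs[key] = value
        return func(*args, **kwargs)

    return wrapper


def add_bias(x, bias):
    return x + bias


worker_env = {"bias": 7}  # stand-in for the dict returned by a setup_fn
wrapped = wrap_with_env(add_bias, lambda: worker_env)
print([wrapped(x) for x in [1, 2]])  # [8, 9]

try:
    wrapped(1, bias=5)
except ValueError as err:
    print(err)  # Arguments given by both worker env and kwargs: ['bias']
```

Checking conflicts before touching `kwargs` means the underlying function is never called with an ambiguous argument, which matches the "before calling the underlying function" guarantee stated above.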
Notes
- Only environment keys that match the function's parameter names (or any keys, if the function accepts **kwargs) are considered for injection.
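The filtering rule in this note can be expressed with `inspect.signature`. `select_env_kwargs` is a hypothetical helper written only for illustration; it assumes that positional-only parameters are excluded, since Python does not allow them to be passed by keyword.

```python
import inspect
from typing import Any, Callable, Dict

# Parameter kinds that can legally be passed by keyword.
_KEYWORD_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)


def select_env_kwargs(func: Callable, env: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keep only env entries that ``func`` can accept by keyword."""
    params = list(inspect.signature(func).parameters.values())
    # A **kwargs catch-all accepts any key, so every env entry is a candidate.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(env)
    # Otherwise only named parameters that may be passed by keyword qualify.
    names = {p.name for p in params if p.kind in _KEYWORD_KINDS}
    return {key: value for key, value in env.items() if key in names}


def add_bias(x, bias):
    return x + bias


def flexible(x, **kwargs):
    return kwargs


env = {"bias": 7, "device": "cpu"}
print(select_env_kwargs(add_bias, env))  # {'bias': 7}
print(select_env_kwargs(flexible, env))  # {'bias': 7, 'device': 'cpu'}
```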
Example
>>> def setup_fn():
... return {"bias": 7}
...
>>> def add_bias(x, bias):
... return x + bias
...
>>> with get_client(local_cfg, setup_fn=setup_fn) as client:
... # 'bias' comes from worker env, no need to pass it explicitly
... futs = client.map(add_bias, [1, 2])
... print(client.gather(futs))
[8, 9]
>>> # Passing conflicting 'bias' both in env and kwargs will error:
>>> import pytest
>>> with pytest.raises(ValueError):
... client.map(add_bias, [1, 2], bias=5)