espnet3.parallel.parallel.wrap_func_with_worker_env
espnet3.parallel.parallel.wrap_func_with_worker_env(func: Callable) → Callable
Wrap a user-defined function for a WorkerPlugin.
This wrapper inspects the function's signature and automatically supplies keyword arguments from the worker's environment (worker.plugins["env"]) when their names match parameters of the function and the caller has not passed them explicitly.
Conflict detection: if both the worker environment and the call's kwargs provide the same argument name, the wrapper raises a ValueError before calling the underlying function.
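The signature-matching and conflict rules above can be sketched with the standard `inspect` module. This is a minimal, hypothetical helper (`resolve_env_kwargs` is not part of espnet3), assuming the environment is a plain dict. It also assumes, beyond what the docstring states, that parameters already filled by positional arguments should not be injected again.

```python
import inspect
from typing import Any, Callable, Dict, Tuple

_P = inspect.Parameter


def resolve_env_kwargs(
    func: Callable,
    env: Dict[str, Any],
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Dict[str, Any]:
    """Hypothetical helper: merge worker-env values into a call's kwargs."""
    # Conflict detection: a name given by both the env and the caller is an error.
    conflicts = sorted(set(env) & set(kwargs))
    if conflicts:
        raise ValueError(f"Arguments given by both worker env and kwargs: {conflicts}")

    params = inspect.signature(func).parameters
    # A **kwargs catch-all means every env key may be injected.
    accepts_var_kw = any(p.kind is _P.VAR_KEYWORD for p in params.values())
    # Names that can legally be passed as keyword arguments.
    keyword_names = {
        name
        for name, p in params.items()
        if p.kind in (_P.POSITIONAL_OR_KEYWORD, _P.KEYWORD_ONLY)
    }
    # Assumption: parameters already bound by positional args are left alone.
    filled_positionally = {
        name
        for name, p in list(params.items())[: len(args)]
        if p.kind in (_P.POSITIONAL_ONLY, _P.POSITIONAL_OR_KEYWORD)
    }

    injected = {
        key: value
        for key, value in env.items()
        if (accepts_var_kw or key in keyword_names) and key not in filled_positionally
    }
    return {**kwargs, **injected}
```

For example, with `def add_bias(x, bias): return x + bias`, the call `add_bias(1, **resolve_env_kwargs(add_bias, {"bias": 7}, (1,), {}))` evaluates to 8, while passing `{"bias": 5}` as the kwargs raises ValueError.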
- Parameters: func (Callable) – The original user-defined function to be executed on the worker. It may have positional parameters, keyword parameters, and/or a **kwargs catch-all.
- Returns: A wrapped function that:
  1. Runs on the worker.
  2. Retrieves the environment dict from worker.plugins["env"].
  3. Detects conflicts with explicit keyword arguments and raises an error if any are found.
  4. Supplies any missing keyword arguments from the environment.
- Return type: Callable
- Raises: ValueError – If at least one parameter name is present both in the worker environment and in the keyword arguments passed to the call.
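The four steps listed under Returns might be arranged as a closure, sketched below under stated assumptions: `wrap_with_env` and its `get_worker` argument are hypothetical, the worker is replaced by a `SimpleNamespace` stand-in whose `plugins["env"]` is a plain dict (modeling the docstring literally), and positional-argument handling is omitted. In a real Dask deployment the worker lookup might go through `distributed.get_worker()`, but that is a guess about the actual implementation.

```python
import functools
import inspect
from types import SimpleNamespace
from typing import Any, Callable


def wrap_with_env(func: Callable, get_worker: Callable[[], Any]) -> Callable:
    """Hypothetical sketch; get_worker stands in for locating the current worker."""
    # Inspect the signature once, at wrap time.
    params = inspect.signature(func).parameters
    accepts_var_kw = any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )

    @functools.wraps(func)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        # Steps 1-2: running on the worker, fetch its environment dict.
        env = get_worker().plugins["env"]
        # Step 3: refuse names supplied by both the env and the caller.
        conflicts = sorted(set(env) & set(kwargs))
        if conflicts:
            raise ValueError(f"Arguments given by both worker env and kwargs: {conflicts}")
        # Step 4: pass matching env entries (or all of them, given **kwargs).
        injected = {k: v for k, v in env.items() if accepts_var_kw or k in params}
        return func(*args, **kwargs, **injected)

    return wrapped


# Stand-in for a worker: only the `plugins` mapping is modeled here.
fake_worker = SimpleNamespace(plugins={"env": {"bias": 7}})


def add_bias(x, bias):
    return x + bias


add_bias_on_worker = wrap_with_env(add_bias, get_worker=lambda: fake_worker)
print([add_bias_on_worker(x) for x in [1, 2]])  # [8, 9]
```

Inspecting the signature once when the function is wrapped keeps the per-call work to reading and filtering the env, while reading the env on every call means each worker uses its own environment.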
Notes
- Only environment keys that match the function's parameter names (or any key, if the function accepts **kwargs) are considered for injection.
Example
>>> import pytest
>>> def setup_fn():
...     return {"bias": 7}
...
>>> def add_bias(x, bias):
...     return x + bias
...
>>> with get_client(local_config, setup_fn=setup_fn) as client:
...     # 'bias' comes from the worker env, so it need not be passed explicitly
...     futs = client.map(add_bias, [1, 2])
...     print(client.gather(futs))
...     # Passing 'bias' both via the env and via kwargs raises ValueError on
...     # the worker; the error surfaces when the futures are gathered
...     with pytest.raises(ValueError):
...         client.gather(client.map(add_bias, [1, 2], bias=5))
[8, 9]