espnet3.parallel.base_runner.BaseRunner
class espnet3.parallel.base_runner.BaseRunner(provider: EnvironmentProvider, *, async_mode: bool = False, async_specs_dir: str | Path = './_async_specs', async_result_dir: str | Path = './_async_results')
Bases: ABC
A thin orchestration layer that runs a static forward over a set of indices.
This class handles:
- Switching among local, parallel, and asynchronous (distributed) modes.
- Injecting per-worker environments supplied by an EnvironmentProvider.
- Keeping forward as a @staticmethod for pickle-safety.
Subclass contract:
- Implement @staticmethod forward(idx, *, dataset, model, **env) -> Any without capturing self.
- Provide an EnvironmentProvider that builds the required env (e.g., dataset/model) for local and worker executions.
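The contract above can be illustrated without ESPnet or Dask. The following is a minimal sketch, not the library's implementation: ToyRunnerBase, ToyRunner, and the env dictionary are hypothetical stand-ins, and the real EnvironmentProvider API is not shown. It demonstrates that a static forward has no bound instance, so sending it to a worker does not drag the runner's state along with it.

```python
from abc import ABC, abstractmethod


class ToyRunnerBase(ABC):
    """Hypothetical stand-in for BaseRunner (not the espnet3 class)."""

    @staticmethod
    @abstractmethod
    def forward(idx, *, dataset, model, **env):
        raise NotImplementedError


class ToyRunner(ToyRunnerBase):
    def __init__(self):
        # Instance state that forward must never capture.
        self.big_state = list(range(1000))

    @staticmethod
    def forward(idx, *, dataset, model, **env):
        # Everything forward needs arrives through keyword arguments.
        return model(dataset[idx])

    def bound_forward(self, idx):
        # Anti-pattern, for comparison only: a bound method carries `self`.
        return idx


# Hypothetical env, standing in for what a provider would build.
env = {"dataset": [1, 2, 3], "model": lambda x: x * 10}

# Local execution is just a loop over indices with the env unpacked.
results = [ToyRunner.forward(i, **env) for i in range(3)]

runner = ToyRunner()
# A static method is a plain function: it has no __self__ attribute.
static_has_self = hasattr(runner.forward, "__self__")
# A bound method references its instance through __self__.
bound_has_self = hasattr(runner.bound_forward, "__self__")
```

Because the static forward holds no reference to the runner, serializing it for a worker involves only the function itself, while the dataset and model are rebuilt on the worker side from the environment.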
- Parameters:
- provider (EnvironmentProvider) – Provider that builds the runtime env.
- async_mode (bool) – If True, use Dask submit with asynchronous shards.
- async_specs_dir (str | Path) – Output directory for per-shard spec JSON files.
- async_num_workers (int | None) – If set, overrides detected worker count to decide how many shards to create.
- async_result_dir (str | Path) – Output directory for per-shard JSONL results.
Notes
- In parallel sync mode (when a Dask cluster is configured), tasks are submitted via parallel_map and results are gathered in order.
- In async mode, the results are written to async_result_dir.
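The async flow described in the notes (per-shard spec JSON files in, per-shard JSONL results out) can be sketched with the standard library alone. This is an illustration under stated assumptions, not how BaseRunner is implemented: the contiguous split, the file names shard_N.json and shard_N.jsonl, the spec fields, and the record layout are all hypothetical, and the shards run sequentially here rather than through Dask.

```python
import json
import tempfile
from pathlib import Path


def split_into_shards(indices, num_shards):
    """Split indices into at most num_shards contiguous, near-equal chunks."""
    indices = list(indices)
    size, extra = divmod(len(indices), num_shards)
    shards, start = [], 0
    for k in range(num_shards):
        end = start + size + (1 if k < extra else 0)
        if end > start:  # skip empty shards when there are few indices
            shards.append(indices[start:end])
        start = end
    return shards


def forward(idx, *, dataset, model, **env):
    # Hypothetical record layout: one JSON object per processed index.
    return {"idx": idx, "value": model(dataset[idx])}


def run_shard(spec_path, result_dir, env):
    """Read one shard spec and write one JSONL line per processed index."""
    spec = json.loads(Path(spec_path).read_text())
    out_path = Path(result_dir) / f"shard_{spec['shard_id']}.jsonl"
    with out_path.open("w") as f:
        for idx in spec["indices"]:
            f.write(json.dumps(forward(idx, **env)) + "\n")
    return out_path


with tempfile.TemporaryDirectory() as tmp:
    specs_dir = Path(tmp) / "_async_specs"
    result_dir = Path(tmp) / "_async_results"
    specs_dir.mkdir()
    result_dir.mkdir()

    env = {"dataset": [1, 2, 3, 4, 5], "model": lambda x: x * x}
    shards = split_into_shards(range(5), num_shards=2)

    # Write one spec per shard, then process it (sequentially in this sketch).
    for shard_id, shard in enumerate(shards):
        spec_path = specs_dir / f"shard_{shard_id}.json"
        spec_path.write_text(json.dumps({"shard_id": shard_id, "indices": shard}))
        run_shard(spec_path, result_dir, env)

    # Collect results by reading every JSONL file back in name order.
    records = []
    for path in sorted(result_dir.glob("*.jsonl")):
        with path.open() as f:
            records.extend(json.loads(line) for line in f)
```

Writing results to files instead of returning them lets the caller submit shards and collect output later, which is the practical difference from the sync mode where results are gathered in order.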
abstractmethod static forward(idx: int, *, dataset, model, **env) → Any
Compute one item for the given index (to be implemented by subclasses).
Keep this as a @staticmethod so that it is pickle-safe for Dask and does not capture self.
- Parameters:
- idx (int) – The input index to process.
- dataset – Dataset object provided via the environment.
- model – Model object provided via the environment.
- **env – Any additional environment entries injected by the provider.
- Returns: Result for the given index.
- Return type: Any
- Raises: NotImplementedError – Always in the base class; implement in subclass.
Example
>>> from espnet3.parallel.base_runner import BaseRunner
>>> class MyRunner(BaseRunner):
...     @staticmethod
...     def forward(idx, *, dataset, model, **env):
...         x = dataset[idx]
...         return model(x)