espnet3.parallel.base_runner.BaseRunner
class espnet3.parallel.base_runner.BaseRunner(provider: EnvironmentProvider, *, batch_size: int | None = None, async_mode: bool = False, async_specs_dir: str | Path = './_async_specs', async_result_dir: str | Path = './_async_results')
Bases: ABC
A thin orchestration layer to run static forward over indices.
This class handles:
- Switching among local, parallel, and asynchronous (distributed) modes.
- Injecting per-worker environments supplied by an EnvironmentProvider.
- Keeping forward as a @staticmethod for pickle-safety.
Subclass contract:
- Implement @staticmethod forward(idx, *, dataset, model, **env) -> Any without capturing self. idx may be a single index or a batch of indices, depending on batch_size.
- Provide an EnvironmentProvider that builds the required env (e.g., dataset/model) for local and worker executions.
- Parameters:
- provider (EnvironmentProvider) – Provider that builds the runtime env.
- batch_size (int | None) – If set, chunk indices into batches of this size before dispatching to forward.
- async_mode (bool) – If True, use Dask submit with asynchronous shards.
- async_specs_dir (str | Path) – Output directory for per-shard spec JSON files.
- async_num_workers (int | None) – If set, overrides detected worker count to decide how many shards to create.
- async_result_dir (str | Path) – Output directory for per-shard JSONL results.
Notes
- In parallel sync mode (when a Dask cluster is configured), tasks are submitted via parallel_map and results are gathered in order.
- In async mode, the results are written to async_result_dir.
Initialize BaseRunner object.
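The batch_size behavior described above can be pictured with a small sketch. This is not the library's implementation: chunk_indices is a hypothetical helper, and how BaseRunner treats a trailing partial batch is an assumption here (the sketch keeps it as a shorter final batch).

```python
from typing import Iterable, Iterator, List


def chunk_indices(indices: Iterable[int], batch_size: int) -> Iterator[List[int]]:
    """Yield consecutive batches of at most batch_size indices.

    Hypothetical helper for illustration; not part of espnet3.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    batch: List[int] = []
    for i in indices:
        batch.append(i)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Assumption: a trailing partial batch is dispatched as-is.
        yield batch


batches = list(chunk_indices(range(7), batch_size=3))
print(batches)  # [[0, 1, 2], [3, 4, 5], [6]]
```

Each yielded list would then be passed to forward as idx, which is why forward has to accept either a single index or a batch.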
classmethod batch_forward(indices: Iterable[int], *, dataset, model, **env) → Any
By default, compute a batch by delegating to forward once per index.
This should be overridden by subclasses that can handle batched inputs.
- Parameters:
- indices (Iterable[int]) – Indices to process as a batch.
- dataset – Dataset object provided via the environment.
- model – Model object provided via the environment.
- **env – Any additional environment entries injected by the provider.
- Returns: Batch result from the runner.
- Return type: Any
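A minimal sketch of the default delegation, under stated assumptions: SketchRunner is a stand-in class that does not import espnet3, and returning a list in input order is an assumption, since the documented return type is only Any.

```python
from typing import Any, Iterable, List


class SketchRunner:
    """Stand-in for a BaseRunner subclass (illustration only)."""

    @staticmethod
    def forward(idx: int, *, dataset, model, **env) -> Any:
        # Process a single index using objects supplied by the environment.
        return model(dataset[idx])

    @classmethod
    def batch_forward(
        cls, indices: Iterable[int], *, dataset, model, **env
    ) -> List[Any]:
        # Assumed default: one forward call per index, results in input order.
        return [cls.forward(i, dataset=dataset, model=model, **env) for i in indices]


dataset = [10, 20, 30]
results = SketchRunner.batch_forward([2, 0], dataset=dataset, model=lambda x: x + 1)
print(results)  # [31, 11]
```

A subclass whose model accepts batched inputs would override batch_forward to make a single model call instead of looping.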
abstractmethod static forward(idx: int | Iterable[int], *, dataset, model, **env) → Any
Compute items for the given index or batch (to be implemented by subclasses).
Keep this as a @staticmethod so that it is pickle-safe for Dask and does not capture self.
- Parameters:
- idx (int | Iterable[int]) – The input index or batch of indices to process.
- dataset – Dataset object provided via the environment.
- model – Model object provided via the environment.
- **env – Any additional environment entries injected by the provider.
- Returns: Result for the given index or batch.
- Return type: Any
- Raises: NotImplementedError – Always in the base class; implement in subclass.
Example
>>> class MyRunner(BaseRunner):
...     @staticmethod
...     def forward(idx, *, dataset, model, **env):
...         if isinstance(idx, int):
...             x = dataset[idx]
...             return model(x)
...         xs = [dataset[i] for i in idx]
...         return model(xs)
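To see both branches of such a forward in action without a Dask cluster, the sketch below calls it directly. ToyRunner mirrors the example above but does not inherit from BaseRunner, and toy_model and the list-based dataset are hypothetical placeholders rather than espnet3 objects.

```python
class ToyRunner:
    """Stand-in that mirrors MyRunner.forward (does not import espnet3)."""

    @staticmethod
    def forward(idx, *, dataset, model, **env):
        if isinstance(idx, int):
            # Single-index path: one item in, one result out.
            return model(dataset[idx])
        # Batch path: gather the items, then make one model call.
        return model([dataset[i] for i in idx])


def toy_model(x):
    """Hypothetical model: doubles a number, or each number in a list."""
    if isinstance(x, list):
        return [2 * v for v in x]
    return 2 * x


dataset = [1.0, 2.5, 4.0]
single = ToyRunner.forward(1, dataset=dataset, model=toy_model)
batch = ToyRunner.forward([0, 2], dataset=dataset, model=toy_model)
print(single)  # 5.0
print(batch)   # [2.0, 8.0]
```

Because forward is a static method that receives dataset and model as keyword arguments, it can be exercised this way in a unit test before being handed to a runner.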