espnet3.parallel.base_runner.BaseRunner
class espnet3.parallel.base_runner.BaseRunner(provider: EnvironmentProvider, batch_size: int | None = None, output_dir: str | Path | None = None, shard_subdir: str = '', resume: bool = True)
Bases: ABC
A thin orchestration layer to run static forward over shard items.
This class handles:
- Switching among local and parallel execution modes.
- Injecting per-worker environments supplied by an EnvironmentProvider.
- Keeping forward as a @staticmethod for pickle-safety.
- Persisting shard-local state so interrupted runs can resume.
- Optionally reducing per-shard results on the worker before they are merged on the driver.
This class is the durable parallel execution path in ESPnet3. It writes shard-local files under output_dir and supports resume by reusing shard state when a previous run completed some splits.
Subclass contract:
- Implement @staticmethod forward(idx, dataset, model, **env) -> Any without capturing self. idx may be a single index or a batch of indices depending on batch_size.
- Provide an EnvironmentProvider that builds the required env (e.g., dataset/model) for local and worker executions.
- Override open_writers/write_record/close_writers/merge_state/merge_shard_files when reducer-style execution is needed.
- Parameters:
- provider (EnvironmentProvider) – Provider that builds the runtime env.
- batch_size (int | None) – If set, chunk indices into batches of this size before dispatching to forward.
- output_dir (str | Path | None) – Root output directory used by reducer-style runners to write shard-local files and merged outputs. This must be provided before calling the runner.
- shard_subdir (str) – Optional sub-path under output_dir where shard directories and manifest are written.
- resume (bool) – Whether to skip shards that already have a done marker.
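The effect of batch_size on what forward receives can be illustrated with a small standalone sketch. This is not the ESPnet3 implementation: chunk_indices is a hypothetical helper, and the shorter final batch is an assumption about how leftover indices are handled.

```python
from typing import Iterable, Iterator, List, Optional, Union


def chunk_indices(
    indices: Iterable[int], batch_size: Optional[int]
) -> Iterator[Union[int, List[int]]]:
    """Yield single indices, or lists of at most ``batch_size`` indices.

    Hypothetical helper, written only to illustrate the batch_size parameter.
    """
    if batch_size is None:
        # No batching: forward would receive one index at a time.
        yield from indices
        return
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batch: List[int] = []
    for i in indices:
        batch.append(i)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        # Assumption: the final batch may be shorter than batch_size.
        yield batch


batches = list(chunk_indices(range(7), batch_size=3))
singles = list(chunk_indices(range(3), batch_size=None))
```

A forward implementation that supports both modes therefore has to accept either an int or an iterable of ints as idx.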
Notes
- When output_dir is set, shard state is written under output_dir / shard_subdir / "split.N".
- A shard is considered complete when its directory contains a done marker.
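The directory layout and the resume rule described in these notes can be sketched without ESPnet3 installed. The helper names shard_dir_for and pending_shards are hypothetical, and the marker file name "done" is an assumption based on the wording above; the real runner may store its marker differently.

```python
from pathlib import Path
import tempfile

DONE_MARKER = "done"  # assumed marker file name, taken from the notes above


def shard_dir_for(output_dir, shard_subdir, shard_id):
    """Return output_dir / shard_subdir / "split.N" for shard N."""
    return Path(output_dir) / shard_subdir / f"split.{shard_id}"


def pending_shards(output_dir, shard_subdir, num_shards, resume=True):
    """List the shard ids that still need to run."""
    pending = []
    for shard_id in range(num_shards):
        marker = shard_dir_for(output_dir, shard_subdir, shard_id) / DONE_MARKER
        if resume and marker.exists():
            continue  # shard completed in a previous run, so skip it
        pending.append(shard_id)
    return pending


with tempfile.TemporaryDirectory() as root:
    # Simulate a previous run that finished shard 0 only.
    finished = shard_dir_for(root, "stats", 0)
    finished.mkdir(parents=True)
    (finished / DONE_MARKER).touch()
    todo_resume = pending_shards(root, "stats", num_shards=3)
    todo_fresh = pending_shards(root, "stats", num_shards=3, resume=False)
```

With resume enabled only shards 1 and 2 remain; with resume disabled every shard is scheduled again.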
Example
>>> from espnet3.parallel.base_runner import BaseRunner
>>> from espnet3.parallel.env_provider import EnvironmentProvider
>>> from omegaconf import OmegaConf
>>> class MyProvider(EnvironmentProvider):
...     def build_env_local(self):
...         return {"dataset": ..., "model": ...}
...     def build_worker_setup_fn(self):
...         def setup():
...             return {"dataset": ..., "model": ...}
...         return setup
>>> class MyRunner(BaseRunner):
...     @staticmethod
...     def forward(idx, dataset, model, **env):
...         return model(dataset[idx])
>>> provider = MyProvider(OmegaConf.create({}))
>>> runner = MyRunner(provider, output_dir="/tmp/out")
>>> runner(range(100))

Initialize BaseRunner object.
static close_writers(writers: Dict[str, Any]) → Dict[str, Any] | None
Close per-shard writers after all items are processed.
Called once at the end of each shard. Override to flush and close any file handles opened in open_writers. The optional return value is merged into the shard state before it is persisted.
- Parameters: writers – The dict returned by open_writers for this shard.
- Returns: Extra entries to merge into the shard state, or None.
- Return type: Optional[Dict[str, Any]]
classmethod finalize_state(state: Dict[str, Any], **env) → Dict[str, Any]
Close writers and finalize the shard state.
- Parameters:
- state – Shard state containing _writers and accumulated data.
- **env – Full worker environment.
- Returns: Finalized state with _writers removed and any extra metadata from close_writers merged in.
- Return type: Dict[str, Any]
abstractmethod static forward(idx: int | Iterable[int], dataset, model, **env) → Any
Compute items for the given index or batch (to be implemented by subclasses).
Keep this as a @staticmethod so that it is pickle-safe for Dask and does not capture self.
- Parameters:
- idx (int | Iterable[int]) – The input index or batch of indices to process.
- dataset – Dataset object provided via the environment.
- model – Model object provided via the environment.
- **env – Any additional environment entries injected by the provider.
- Returns: Result for the given index or batch.
- Return type: Any
- Raises: NotImplementedError – Always in the base class; implement in subclass.
Example
>>> class MyRunner(BaseRunner):
...     @staticmethod
...     def forward(idx, dataset, model, **env):
...         if isinstance(idx, int):
...             x = dataset[idx]
...             return model(x)
...         xs = [dataset[i] for i in idx]
...         return model(xs)

classmethod init_state(shard_id: int = 0, output_dir: str = '', shard_subdir: str = '', **env) → Dict[str, Any]
Build the initial state dict for one shard and open its writers.
- Parameters:
- shard_id – Zero-based shard index.
- output_dir – Root output directory (string form).
- shard_subdir – Optional sub-path appended to output_dir.
- **env – Full worker environment forwarded to open_writers.
- Returns: Initial state with keys shard_id, shard_dir, _writers, and records.
- Return type: Dict[str, Any]
Example
>>> state = MyRunner.init_state(
...     shard_id=0, output_dir="/tmp/out", dataset=ds, model=md
... )
>>> print(state["shard_dir"])
/tmp/out/split.0

classmethod is_shard_done(shard_dir: Path) → bool
Return True if the shard has a completion marker file.
merge(shard_dirs: List[Path]) → Any
Merge completed shard outputs into the final result.
Called on the driver after all shards finish. Override to aggregate per-shard files (e.g., SCP files, stats arrays) into a single output.
- Parameters: shard_dirs – Ordered list of completed shard directories.
- Returns: Aggregated result, or None if outputs are written to disk and no in-memory result is needed.
- Return type: Any
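As one possible illustration of a merge override, the sketch below concatenates a per-shard SCP file into a single file, preserving shard order. It is a standalone function rather than a BaseRunner subclass, and the name merge_scp, the file name wav.scp, and the choice to skip shards without that file are all assumptions made for the example.

```python
from pathlib import Path
from typing import List
import tempfile


def merge_scp(shard_dirs: List[Path], name: str, merged_path: Path) -> int:
    """Concatenate ``name`` from each shard directory in order; return line count."""
    count = 0
    with merged_path.open("w", encoding="utf-8") as out:
        for shard_dir in shard_dirs:
            part = shard_dir / name
            if not part.exists():
                continue  # assumption: a shard may produce no output file
            with part.open(encoding="utf-8") as f:
                for line in f:
                    # Guard against a missing trailing newline in a shard file.
                    out.write(line if line.endswith("\n") else line + "\n")
                    count += 1
    return count


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    shard_dirs = []
    for shard_id, utts in enumerate([["utt1 a.wav", "utt2 b.wav"], ["utt3 c.wav"]]):
        d = root / f"split.{shard_id}"
        d.mkdir()
        (d / "wav.scp").write_text("\n".join(utts) + "\n", encoding="utf-8")
        shard_dirs.append(d)
    n_lines = merge_scp(shard_dirs, "wav.scp", root / "wav.scp")
    merged_text = (root / "wav.scp").read_text(encoding="utf-8")
```

Because the output lives on disk, a merge written this way could return None or a small summary such as the line count.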
static open_writers(shard_dir: Path | None, **env) → Dict[str, Any]
Open per-shard file writers before processing begins.
Called once at the start of each shard. Override to open output file handles or other resources that accumulate results across multiple forward calls within a shard.
- Parameters:
- shard_dir – Directory dedicated to this shard’s output files.
- **env – Full worker environment (dataset, model, etc.).
- Returns: A writers dict passed to every write_record and close_writers call for this shard.
- Return type: Dict[str, Any]
classmethod reduce_state(state: Dict[str, Any], result: Any, **env) → Dict[str, Any]
Fold a single forward result into the shard state.
- Parameters:
- state – Current shard state (mutated in place via write_record).
- result – Value returned by forward for one item or batch.
- **env – Full worker environment.
- Returns: Updated shard state.
- Return type: Dict[str, Any]
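Taken together, init_state, reduce_state, and finalize_state describe a fold over the items of one shard. The sketch below mimics that lifecycle with plain Python so it runs without ESPnet3; run_shard and square are hypothetical names, and appending to records is an assumed default reduction, not a statement about what the library does internally.

```python
from typing import Any, Callable, Dict, Iterable


def run_shard(
    forward: Callable[..., Any], indices: Iterable[int], shard_id: int, **env
) -> Dict[str, Any]:
    """Fold forward results for one shard into a state dict (illustrative only)."""
    # 1. init: stand-in for init_state (no real writers are opened here).
    state: Dict[str, Any] = {"shard_id": shard_id, "_writers": {}, "records": []}
    # 2. reduce: fold each forward result into the state, one item at a time.
    for idx in indices:
        result = forward(idx, **env)
        state["records"].append(result)  # stand-in for reduce_state
    # 3. finalize: drop the writers entry, as finalize_state is documented to do.
    state.pop("_writers", None)
    return state


def square(idx, dataset, model=None, **env):
    """Toy forward: square the dataset item at idx."""
    return dataset[idx] ** 2


final = run_shard(square, range(3), shard_id=0, dataset=[2, 3, 4], model=None)
```

Running the reduction on the worker keeps the data sent back to the driver small, since only the finalized state needs to be returned.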
static write_record(writers: Dict[str, Any], result: Any, state: Dict[str, Any], **env) → None
Persist one forward result into the shard state or files.
Called after each forward invocation. Override to stream results into open file handles instead of accumulating them in memory.
- Parameters:
- writers – The dict returned by open_writers for this shard.
- result – The value returned by forward for this item.
- state – Mutable shard state dict (records, shard_id, etc.).
- **env – Full worker environment.
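The three writer hooks (open_writers, write_record, close_writers) are meant to be used as a set. The following sketch shows how they might cooperate to stream results to a text file instead of holding them in memory. The functions are standalone stand-ins with the documented signatures; the "text" writer key, the (utt_id, hyp) result shape, and the num_written and closed state entries are assumptions for illustration.

```python
from pathlib import Path
from typing import Any, Dict, Optional
import tempfile


def open_writers(shard_dir: Optional[Path], **env) -> Dict[str, Any]:
    """Open one text file per shard; return no writers when there is no directory."""
    if shard_dir is None:
        return {}
    shard_dir.mkdir(parents=True, exist_ok=True)
    return {"text": (shard_dir / "text").open("w", encoding="utf-8")}


def write_record(writers, result, state, **env) -> None:
    """Stream one (utt_id, hyp) result to disk and track a counter in the state."""
    utt_id, hyp = result
    writers["text"].write(f"{utt_id} {hyp}\n")
    state["num_written"] = state.get("num_written", 0) + 1


def close_writers(writers) -> Optional[Dict[str, Any]]:
    """Close every handle; the returned dict is merged into the shard state."""
    for handle in writers.values():
        handle.close()
    return {"closed": True}


with tempfile.TemporaryDirectory() as root:
    shard_dir = Path(root) / "split.0"
    state = {"shard_id": 0, "records": []}
    writers = open_writers(shard_dir)
    for result in [("utt1", "hello"), ("utt2", "world")]:
        write_record(writers, result, state)
    extra = close_writers(writers)
    state.update(extra or {})
    written = (shard_dir / "text").read_text(encoding="utf-8")
```

In a real subclass these would be declared as static methods on the runner, matching the signatures documented above.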
