ESPnet3 Provider And Runner
This page is for people who want to implement or modify a parallel workload.
Start with ESPnet3 Parallel for the high-level flow. This page focuses on:
- subclass contracts
- writer hooks
- shard-local files
- implementation snippets
What this page covers
EnvironmentProvider and BaseRunner are the two main extension points.
- EnvironmentProvider decides how runtime objects are built
- BaseRunner decides how one shard is processed and how outputs are written
Use this page when you are:
- adding a new parallel inference-like task
- changing shard output formats
- debugging writer / merge behavior
- deciding which hook to override
Read the real contracts first
Read the generated API docs before changing these classes:
Useful concrete examples:
EnvironmentProvider
EnvironmentProvider has exactly two required methods:
build_env_local()
Use this for local execution.
It should build one env dict on the driver and return it directly.
```python
from espnet3.parallel.env_provider import EnvironmentProvider


class MyProvider(EnvironmentProvider):
    def build_env_local(self):
        # build_dataset / build_model / build_tokenizer stand for your own builders.
        return {
            "dataset": build_dataset(self.config),
            "model": build_model(self.config),
            "tokenizer": build_tokenizer(self.config),
        }
```
build_worker_setup_fn()
Use this for distributed execution.
It must return a zero-argument function. That returned function runs once per worker.
```python
class MyProvider(EnvironmentProvider):
    def build_worker_setup_fn(self):
        # Capture only the config, so the returned closure does not hold `self`.
        config = self.config

        def setup():
            # Runs once per worker.
            return {
                "dataset": build_dataset(config),
                "model": build_model(config),
                "tokenizer": build_tokenizer(config),
            }

        return setup
```
Local vs worker timing
Keep this rule in mind:
- local mode: build once on the driver
- Dask / SLURM mode: build once per worker
Do not build large objects inside forward().
InferenceProvider
If your task is inference-like, check espnet3/parallel/inference_provider.py.
It prebuilds the local env once:
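```python
from abc import ABC

from espnet3.parallel.env_provider import EnvironmentProvider


class InferenceProvider(EnvironmentProvider, ABC):
    def __init__(self, config, params=None):
        super().__init__(config)
        self.params = params or {}
        # Run the worker setup once on the driver and cache the result.
        self._local_env = self.build_worker_setup_fn()()

    def build_env_local(self):
        # Return a shallow copy of the cached env.
        return dict(self._local_env)
```
That pattern is useful when: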
- local mode should avoid repeated model loading
- worker setup logic and local setup logic should stay identical
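You can check the effect of this prebuild-once pattern in isolation. The sketch below copies the shape of InferenceProvider using hypothetical classes (`SketchProvider`, `CountingProvider`) and a build counter. It is not ESPnet3 code. It shows that setup runs once and that each `build_env_local()` call returns a new dict holding the same objects.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict


class SketchProvider(ABC):
    """Hypothetical stand-in mirroring the prebuild-once shape; not ESPnet3 code."""

    def __init__(self, config: dict):
        self.config = config
        # Reuse the worker setup logic to build the local env exactly once.
        self._local_env = self.build_worker_setup_fn()()

    @abstractmethod
    def build_worker_setup_fn(self) -> Callable[[], Dict]:
        ...

    def build_env_local(self) -> Dict:
        # Shallow copy: the dict is new, the objects inside are shared.
        return dict(self._local_env)


class CountingProvider(SketchProvider):
    builds = 0  # how many times setup() has run

    def build_worker_setup_fn(self):
        config = self.config

        def setup():
            CountingProvider.builds += 1
            return {"model": object(), "name": config["name"]}

        return setup


provider = CountingProvider({"name": "demo"})
env_a = provider.build_env_local()
env_b = provider.build_env_local()
env_a["extra"] = 1  # does not leak into the cached env or env_b
print(CountingProvider.builds, env_a["model"] is env_b["model"], "extra" in env_b)  # 1 True False
```

`dict(...)` copies only the top-level mapping. A model object is therefore shared by every caller, which is what avoids repeated loading. It also means any in-place change to the model is visible everywhere.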
BaseRunner
BaseRunner handles:
- batching indices
- shard planning
- resume
- local vs Dask dispatch
- shard-local writer lifecycle
- final merge
The main method you must implement is forward(...).
Minimal runner
```python
from espnet3.parallel.base_runner import BaseRunner


class MyRunner(BaseRunner):
    @staticmethod
    def forward(idx, dataset, model, **env):
        sample = dataset[idx]
        return model(sample)
```
Important constraints:
- keep forward() as @staticmethod
- do not capture self
- env keys are injected by name
Name-based env injection
If the provider returns:
```python
{
    "dataset": dataset,
    "model": model,
    "device": device,
}
```
then forward() can declare:
```python
@staticmethod
def forward(idx, dataset, model, device, **env):
    ...
```
The parameter names must match the env dict keys.
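This page does not show ESPnet3's actual injection code. The sketch below is a hypothetical helper (`call_with_env`) built on `inspect.signature`. It illustrates how matching by parameter name can work and why mismatched names fail with a `TypeError`.

```python
import inspect
from typing import Any, Callable, Dict


def call_with_env(fn: Callable, idx: Any, env: Dict[str, Any]) -> Any:
    """Hypothetical helper: pass env entries to fn by parameter name."""
    params = inspect.signature(fn).parameters
    takes_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
    if takes_var_kw:
        # forward(idx, dataset, model, **env): every key can be passed;
        # named parameters pick up the keys that match their names.
        return fn(idx, **env)
    # No **env: pass only the keys that fn names explicitly.
    return fn(idx, **{k: v for k, v in env.items() if k in params})


class Runner:
    @staticmethod
    def forward(idx, dataset, model, device, **env):
        return f"{model(dataset[idx])}@{device}"


env = {"dataset": ["a", "b"], "model": str.upper, "device": "cpu", "tokenizer": None}
print(call_with_env(Runner.forward, 1, env))  # B@cpu

bad_env = {"ds": ["a", "b"], "net": str.upper, "device": "cpu"}
try:
    call_with_env(Runner.forward, 1, bad_env)
except TypeError as err:
    # dataset and model were never supplied under those names
    print("mismatched names:", err)
```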
Batch-aware forward()
If batch_size is set on the runner, idx may be a batch.
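As an illustration only, here is one way indices could be grouped before they reach forward(). `make_batches` is a hypothetical helper, not BaseRunner's batching code. It only shows the two shapes `idx` can take.

```python
from typing import List, Optional, Sequence, Union


def make_batches(
    indices: Sequence[int], batch_size: Optional[int] = None
) -> List[Union[int, List[int]]]:
    """Hypothetical: what forward() might receive as `idx`."""
    if not batch_size:
        # No batching: forward() sees one int per call.
        return [int(i) for i in indices]
    # Batching: forward() sees a list of up to batch_size ints per call.
    return [
        [int(i) for i in indices[start:start + batch_size]]
        for start in range(0, len(indices), batch_size)
    ]


print(make_batches(range(5)))     # [0, 1, 2, 3, 4]
print(make_batches(range(5), 2))  # [[0, 1], [2, 3], [4]]
```

The last batch may be shorter than `batch_size`, so a batched model call should not assume a fixed batch length.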
If the same runner may run with or without batch_size, write forward() so it accepts both a single integer index and a list of indices.
```python
@staticmethod
def forward(idx, dataset, model, **env):
    if isinstance(idx, int):
        return model(dataset[idx])
    batch = [dataset[i] for i in idx]
    return model(batch)
```
Which hook should you override?
Use this rule:
- only compute one result in memory: override forward()
- write shard-local files incrementally: override the writer hooks
- combine shard files on the driver: override merge()
The main hooks are documented in the BaseRunner API reference:
- open_writers(shard_dir, **env)
- write_record(writers, result, state, **env)
- close_writers(writers)
- merge(shard_dirs)
Lower-level state hooks also exist, such as init_state(), reduce_state(), and finalize_state() (see the writer lifecycle below).
Most subclasses should not override those lower-level methods first.
Writer lifecycle
One shard roughly runs like this:
```python
state = cls.init_state(shard_id=shard_id, **env)
for item in items:
    result = cls.forward(item, **env)
    state = cls.reduce_state(state, result, shard_id=shard_id, **env)
cls.finalize_state(state, shard_id=shard_id, **env)
```
And by default:
- init_state() creates split.N/
- open_writers() returns a writer dict
- write_record() appends to state["records"]
- close_writers() closes handles
- a done file is written after successful completion
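The default lifecycle above can be imitated with the standard library alone. In this minimal sketch, `SketchTextRunner` and `run_shard` are hypothetical and much simpler than BaseRunner: there is no forward() call and no state hooks beyond a plain dict. What it does show is the ordering. The shard directory is created first, writers are opened, records are streamed, handles are closed even on failure, and `done` is written only if nothing raised.

```python
import tempfile
from pathlib import Path
from typing import Any, Dict, Iterable


class SketchTextRunner:
    """Hypothetical stand-in with the same writer-hook names; not BaseRunner."""

    @staticmethod
    def open_writers(shard_dir: Path, **env) -> Dict[str, Any]:
        return {"text": (shard_dir / "text").open("w", encoding="utf-8")}

    @staticmethod
    def write_record(writers, result, state, **env) -> None:
        writers["text"].write(f'{result["utt_id"]} {result["text"]}\n')

    @staticmethod
    def close_writers(writers) -> None:
        for handle in writers.values():
            handle.close()


def run_shard(runner, output_dir: Path, shard_id: int, results: Iterable[dict]) -> Path:
    """Simplified lifecycle of one shard (hypothetical driver loop)."""
    shard_dir = output_dir / f"split.{shard_id}"
    shard_dir.mkdir(parents=True, exist_ok=True)  # role of init_state()
    state = {"shard_id": shard_id, "shard_dir": shard_dir}
    writers = runner.open_writers(shard_dir)
    try:
        for result in results:
            runner.write_record(writers, result, state)
    finally:
        runner.close_writers(writers)  # release handles even on failure
    # Only reached when no exception escaped, so "done" means "complete".
    (shard_dir / "done").touch()
    return shard_dir


def failing_results():
    yield {"utt_id": "u2", "text": "partial"}
    raise RuntimeError("simulated model failure")


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    ok_dir = run_shard(SketchTextRunner, out, 0, [{"utt_id": "u1", "text": "hello"}])
    try:
        run_shard(SketchTextRunner, out, 1, failing_results())
    except RuntimeError:
        pass
    done_flags = ((ok_dir / "done").exists(), (out / "split.1" / "done").exists())
    shard0_text = (ok_dir / "text").read_text(encoding="utf-8")
    partial_text = (out / "split.1" / "text").read_text(encoding="utf-8")

print(done_flags)  # (True, False)
```

After the failure, the crashed shard still holds a partial `text` file but no `done` marker. That is why completion is judged by the marker and not by the presence of output files.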
Minimal file-writing runner
This is the smallest useful pattern when results should be streamed to disk.
```python
from pathlib import Path

from espnet3.parallel.base_runner import BaseRunner


class MyTextRunner(BaseRunner):
    @staticmethod
    def forward(idx, dataset, model, **env):
        sample = dataset[idx]
        hyp = model(sample)
        return {"utt_id": sample["utt_id"], "text": hyp}

    @staticmethod
    def open_writers(shard_dir: Path, **env):
        # One open handle per shard-local output file.
        return {
            "text": (shard_dir / "text").open("w", encoding="utf-8"),
        }

    @staticmethod
    def write_record(writers, result, state, **env):
        # Stream one "<utt_id> <text>" line per result.
        writers["text"].write(f'{result["utt_id"]} {result["text"]}\n')

    @staticmethod
    def close_writers(writers):
        for handle in writers.values():
            handle.close()
        return None

    def merge(self, shard_dirs):
        out_dir = self.output_dir / self.shard_subdir if self.shard_subdir else self.output_dir
        out_dir.mkdir(parents=True, exist_ok=True)
        # Concatenate every shard's "text" file into one final file.
        with (out_dir / "text").open("w", encoding="utf-8") as out_f:
            for shard_dir in sorted(shard_dirs):
                shard_path = shard_dir / "text"
                if not shard_path.exists():
                    continue
                out_f.write(shard_path.read_text(encoding="utf-8"))
        return {}
```
State-accumulating runner
If outputs are small, you may not need writers.
The default write_record() already appends each result into state["records"].
```python
from espnet3.parallel.base_runner import BaseRunner


class MyCollectRunner(BaseRunner):
    @staticmethod
    def forward(idx, dataset, model, **env):
        return {"idx": idx, "score": float(model(dataset[idx]))}

    def merge(self, shard_dirs):
        # read shard-local state or ignore merge if caller only needs side effects
        return None
```
If you keep everything in memory, check carefully whether that still scales for your dataset size.
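Here is the in-memory approach in miniature. This is a hypothetical sketch: `write_record_default_style` imitates the documented default of appending to `state["records"]`. This page does not describe how shard state reaches merge(), so `merge_mean_score` simply receives the per-shard states directly. Every record stays in memory until the final reduction, which is the scaling concern noted above.

```python
from statistics import mean
from typing import Callable, Dict, List, Sequence


def write_record_default_style(writers, result, state, **env) -> None:
    # Same behaviour the page describes for the default write_record().
    state["records"].append(result)


def run_collect_shard(indices: Sequence[int], dataset, model: Callable) -> Dict[str, list]:
    state: Dict[str, list] = {"records": []}
    for idx in indices:
        result = {"idx": idx, "score": float(model(dataset[idx]))}
        write_record_default_style({}, result, state)
    return state


def merge_mean_score(shard_states: List[Dict[str, list]]) -> float:
    # Every record of every shard is held in memory at this point.
    scores = [r["score"] for s in shard_states for r in s["records"]]
    return mean(scores)


def model(x: int) -> int:  # stand-in "model"
    return x * 2


dataset = [1, 2, 3, 4]
states = [run_collect_shard([0, 1], dataset, model), run_collect_shard([2, 3], dataset, model)]
print(merge_mean_score(states))  # 5.0
```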
Real example: InferenceRunner
espnet3/systems/base/inference_runner.py is the best reference for writer-style parallel output.
Key ideas from that implementation:
- open_writers() prepares shard-local SCP metadata
- write_record() validates one result and writes <field>.scp
- close_writers() closes handles and writes field_keys.txt
- merge() concatenates shard-local SCP fragments into final outputs
The write path looks like this:
```python
@staticmethod
def open_writers(shard_dir, output_artifacts=None, **env):
    return {
        "shard_dir": shard_dir,
        "artifact_configs": output_artifacts or {},
        "scp_handles": {},
        "field_keys": set(),
    }
```

```python
@staticmethod
def write_record(writers, result, state, idx_key="utt_id", **env):
    for output in _iter_outputs(result):
        idx_value = output[idx_key]
        # Abridged excerpt: validation, and how field_keys and value are
        # derived from `output`, are omitted here.
        for field_key in field_keys:
            handle = writers["scp_handles"].get(field_key)
            if handle is None:
                # Open <field>.scp lazily on first use in this shard.
                handle = (writers["shard_dir"] / f"{field_key}.scp").open(
                    "w", encoding="utf-8"
                )
                writers["scp_handles"][field_key] = handle
            handle.write(f"{idx_value} {value}\n")
```
And merge is just ordered shard-file concatenation:
```python
for field_key in field_keys:
    concatenate_shard_files(
        ordered_shard_dirs,
        f"{field_key}.scp",
        base_dir / f"{field_key}.scp",
    )
```
See concatenate_shard_files() for the exact file merge behavior.
That pattern is the right choice when:
- each result becomes one or more output files
- per-shard streaming is cheaper than large Python lists
- final outputs should look like normal ESPnet artifacts
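The sketch below is a self-contained, simplified version of the same writer-style idea that uses only the standard library. `open_writers`, `write_record`, `close_writers`, and `merge` here are hypothetical plain functions, not InferenceRunner's methods. There is no validation and no output_artifacts handling, and every non-index key of a result is treated as an SCP field. The one-key-per-line layout of `field_keys.txt` is an assumption. `merge` stands in for concatenate_shard_files() with assumed behavior: it appends each shard's fragment in the given order and skips shards that lack it.

```python
import tempfile
from pathlib import Path
from typing import List


def open_writers(shard_dir: Path) -> dict:
    # Handles are opened lazily, so only bookkeeping is created here.
    return {"shard_dir": shard_dir, "scp_handles": {}, "field_keys": set()}


def write_record(writers: dict, result: dict, idx_key: str = "utt_id") -> None:
    idx_value = result[idx_key]
    for field_key, value in result.items():
        if field_key == idx_key:
            continue
        handle = writers["scp_handles"].get(field_key)
        if handle is None:
            # First time this field appears in this shard: open <field>.scp.
            handle = (writers["shard_dir"] / f"{field_key}.scp").open("w", encoding="utf-8")
            writers["scp_handles"][field_key] = handle
            writers["field_keys"].add(field_key)
        handle.write(f"{idx_value} {value}\n")


def close_writers(writers: dict) -> None:
    for handle in writers["scp_handles"].values():
        handle.close()
    # Record which fields this shard produced, for merge() to read back.
    keys = sorted(writers["field_keys"])
    (writers["shard_dir"] / "field_keys.txt").write_text(
        "".join(f"{k}\n" for k in keys), encoding="utf-8"
    )


def merge(ordered_shard_dirs: List[Path], base_dir: Path) -> List[str]:
    field_keys = set()
    for shard_dir in ordered_shard_dirs:
        keys_file = shard_dir / "field_keys.txt"
        if keys_file.exists():
            field_keys.update(keys_file.read_text(encoding="utf-8").split())
    for field_key in sorted(field_keys):
        # Append each shard's fragment in shard order; skip missing ones.
        with (base_dir / f"{field_key}.scp").open("w", encoding="utf-8") as out_f:
            for shard_dir in ordered_shard_dirs:
                part = shard_dir / f"{field_key}.scp"
                if part.exists():
                    out_f.write(part.read_text(encoding="utf-8"))
    return sorted(field_keys)


shard_rows = [
    [{"utt_id": "u1", "hyp": "a b", "ref": "a b"}],
    [{"utt_id": "u2", "hyp": "c"}, {"utt_id": "u3", "hyp": "d"}],
]
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    shard_dirs = []
    for shard_id, rows in enumerate(shard_rows):
        shard_dir = root / f"split.{shard_id}"
        shard_dir.mkdir()
        writers = open_writers(shard_dir)
        for row in rows:
            write_record(writers, row)
        close_writers(writers)
        shard_dirs.append(shard_dir)
    merged_keys = merge(shard_dirs, root)
    merged_hyp = (root / "hyp.scp").read_text(encoding="utf-8")
    merged_ref = (root / "ref.scp").read_text(encoding="utf-8")

print(merged_keys)  # ['hyp', 'ref']
```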
Output directory behavior
If output_dir is set, shard-local work is written under:
```text
output_dir/
  shard_subdir/
    manifest.json
    split.0/
    split.1/
    ...
```
The done marker is:
```text
split.N/done
```
Resume behavior depends on that file.
If resume=True, completed shards are skipped.
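The resume rule can be sketched as follows, assuming shards live directly under one directory as split.N/ (shard_subdir and manifest.json are ignored here). `pending_shards` is a hypothetical helper, not BaseRunner's implementation.

```python
import tempfile
from pathlib import Path
from typing import List


def pending_shards(output_dir: Path, num_shards: int, resume: bool) -> List[int]:
    """Hypothetical: which split.N shards still need to run."""
    if not resume:
        return list(range(num_shards))
    # Only the done marker counts; partial files without it are rerun.
    return [i for i in range(num_shards) if not (output_dir / f"split.{i}" / "done").exists()]


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "split.0").mkdir()
    (root / "split.0" / "done").touch()  # finished shard
    (root / "split.1").mkdir()
    (root / "split.1" / "text").write_text("u1 partial\n", encoding="utf-8")  # crashed shard
    resumed = pending_shards(root, 3, resume=True)
    fresh = pending_shards(root, 3, resume=False)

print(resumed, fresh)  # [1, 2] [0, 1, 2]
```

Because only the marker is checked, a shard completed under an older configuration is skipped as well. This is how resume=True can hide stale shard outputs while debugging.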
Common implementation patterns
Pattern 1: plain local computation
Use:
- simple provider
- forward() only
- no writer hooks
Good for:
- debugging
- small outputs
- tests
Pattern 2: inference output writing
Use:
- provider that builds dataset/model
- forward() returning normalized dicts
- open_writers() / write_record() / close_writers()
- merge() that assembles final files
Good for:
- SCP outputs
- JSONL fragments
- per-utterance artifacts
Pattern 3: worker-local initialization
Use:
- build_worker_setup_fn() to construct heavy objects on each worker
- env injection by name into forward()
Good for:
- GPU models
- datasets with file handles
- large tokenizer/model objects
Common mistakes
Capturing self in forward()
Do not do this:
```python
class BadRunner(BaseRunner):
    def forward(self, idx):  # wrong
        ...
```
Use:
```python
class GoodRunner(BaseRunner):
    @staticmethod
    def forward(idx, dataset, model, **env):
        ...
```
Rebuilding the model inside forward()
Do not load a checkpoint per item.
Build it in the provider.
Returning the env instead of a setup function
Wrong:
```python
def build_worker_setup_fn(self):
    return {"dataset": ..., "model": ...}
```
Correct:
```python
def build_worker_setup_fn(self):
    def setup():
        return {"dataset": ..., "model": ...}

    return setup
```
Using mismatched env names
Wrong:
```python
return {"ds": dataset, "net": model}
```
with
```python
def forward(idx, dataset, model, **env):
    ...
```
Correct the names or read from **env explicitly.
Forgetting shard merge semantics
If your subclass writes shard-local files but merge() does nothing, the final outputs stay split across the split.N/ directories.
That may be fine for debugging. It is usually wrong for stage-facing outputs.
Practical debugging checklist
When a new runner does not behave correctly, check these first:
- forward() is @staticmethod
- provider env keys match forward() parameter names
- output_dir is set when using writer hooks
- shard directories contain expected files
- done is written only after successful completion
- merge() reads shards in stable order
- resume=True is not hiding stale shard outputs during debugging
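A small standard-library helper can automate several of these checks. `ordered_shard_dirs` and `check_shards` below are hypothetical debugging aids, not part of ESPnet3. They also illustrate an ordering detail: plain `sorted()` on paths compares the names as text. The MyTextRunner.merge() example earlier uses `sorted(shard_dirs)`, so with ten or more shards that would place split.10 before split.2. Sorting on the integer N keeps shard order numeric.

```python
import re
import tempfile
from pathlib import Path
from typing import Dict, List

SPLIT_RE = re.compile(r"split\.(\d+)")


def ordered_shard_dirs(root: Path) -> List[Path]:
    """split.N directories under root, in numeric shard order."""
    dirs = [p for p in root.iterdir() if p.is_dir() and SPLIT_RE.fullmatch(p.name)]
    # Sort by the integer N; sorted(dirs) alone compares names as text.
    return sorted(dirs, key=lambda p: int(SPLIT_RE.fullmatch(p.name).group(1)))


def check_shards(root: Path, expected_files: List[str]) -> Dict[str, List[str]]:
    """Hypothetical debug helper: map shard name -> list of problems."""
    problems: Dict[str, List[str]] = {}
    for shard_dir in ordered_shard_dirs(root):
        issues = []
        if not (shard_dir / "done").exists():
            issues.append("missing done marker")
        issues += [f"missing {name}" for name in expected_files if not (shard_dir / name).exists()]
        if issues:
            problems[shard_dir.name] = issues
    return problems


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for i in range(12):
        shard_dir = root / f"split.{i}"
        shard_dir.mkdir()
        if i != 2:
            (shard_dir / "text").write_text("", encoding="utf-8")
        if i != 10:
            (shard_dir / "done").touch()
    numeric = [p.name for p in ordered_shard_dirs(root)]
    lexical = [p.name for p in sorted(ordered_shard_dirs(root))]
    report = check_shards(root, ["text"])

print(numeric[:4], lexical[:4])  # ['split.0', 'split.1', 'split.2', 'split.3'] ['split.0', 'split.1', 'split.10', 'split.11']
print(report)  # {'split.2': ['missing text'], 'split.10': ['missing done marker']}
```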
See also
ESPnet3 Parallel
Return to the high-level parallel execution overview.
Inference Provider
See the inference-stage provider pattern and YAML wiring.
Parallel Config
Review local, local GPU, and cluster backend settings.
EnvironmentProvider API
Read the generated contract for local and worker env setup.
BaseRunner API
Read the generated contract for forward, writer hooks, and merge.
InferenceRunner API
Inspect the writer-style runner used by base inference.
