espnet3.systems.base.inference_runner.InferenceRunner
class espnet3.systems.base.inference_runner.InferenceRunner(provider: EnvironmentProvider, idx_key: str = 'utt_id', hyp_key: str | Sequence[str] = 'hyp', ref_key: str | Sequence[str] = 'ref', **kwargs)
Bases: BaseRunner
Inference runner with strict output-format validation.
This runner implements forward to call a recipe-provided output function. The key names are configurable via idx_key and hyp_key/ref_key. hyp_key and ref_key may be a single string or a list of strings to support multiple hypothesis/reference fields. idx_key is the key used to map each inference result to its source dataset index when writing SCP files.
Output format requirements:
- The result is a dict with the configured keys plus any extra fields.
- A sample identifier key must exist under idx_key so SCP outputs can map each result back to the corresponding dataset sample.
- The sample identifier must be a single value, not a list or tuple.
- hyp_key and ref_key values may be scalars or lists/tuples. If lists are returned, each entry is written to its own SCP file (e.g., hyp0.scp, hyp1.scp).
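The requirements above can be sketched as a small standalone check. This is only an illustration under stated assumptions, not the runner's actual validation code: `check_output` and `as_key_list` are hypothetical helper names, and the exception types are chosen for the example.

```python
from typing import Any, Dict, List, Sequence, Union

Keys = Union[str, Sequence[str]]


def as_key_list(keys: Keys) -> List[str]:
    """Normalize a single key or a sequence of keys to a list."""
    return [keys] if isinstance(keys, str) else list(keys)


def check_output(
    output: Dict[str, Any],
    idx_key: str = "utt_id",
    hyp_key: Keys = "hyp",
    ref_key: Keys = "ref",
) -> None:
    """Hypothetical check mirroring the documented format rules."""
    if not isinstance(output, dict):
        raise TypeError("output must be a dict")
    if idx_key not in output:
        raise ValueError(f"missing sample identifier key: {idx_key!r}")
    # The identifier must be a single value, not a list or tuple.
    if isinstance(output[idx_key], (list, tuple)):
        raise ValueError(f"{idx_key!r} must be a single value")
    # Every configured hypothesis/reference key must be present.
    for key in as_key_list(hyp_key) + as_key_list(ref_key):
        if key not in output:
            raise ValueError(f"missing output key: {key!r}")


# A dict that satisfies the rules: scalar id, list-valued hyp, scalar ref,
# plus one extra field ("score"), which the format allows.
result = {
    "utt_id": "utt_0001",
    "hyp": ["hello world", "hello word"],
    "ref": "hello world",
    "score": -1.2,
}
check_output(result)
```

With a list-valued `hyp` like the one above, the documented behavior is that each entry goes to its own SCP file (hyp0.scp, hyp1.scp).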
- Parameters:
- provider (EnvironmentProvider) – Provider that supplies dataset/model/env.
- idx_key (str) – Output dict key used as the sample identifier in SCP files. Defaults to "utt_id".
- hyp_key (str | Sequence[str]) – Hypothesis key(s) expected in output.
- ref_key (str | Sequence[str]) – Reference key(s) expected in output.
- **kwargs – Forwarded to BaseRunner (e.g., output_dir, batch_size, resume).
Example
>>> from espnet3.parallel.inference_provider import InferenceProvider
>>> class MyProvider(InferenceProvider):
...     @staticmethod
...     def build_dataset(config): return load_dataset(config)
...     @staticmethod
...     def build_model(config): return load_model(config)
>>> runner = InferenceRunner(
...     MyProvider(config),
...     output_dir="/exp/decode",
...     idx_key="utt_id",
...     hyp_key="hyp",
...     ref_key="ref",
... )
>>> runner(range(len(test_dataset)))
Initialize the inference runner with output key settings.
- Parameters:
- provider – Environment provider that supplies dataset/model/env.
- idx_key – Output dict key used as the sample identifier written in the first column of each SCP line. This ties each inference result back to its dataset sample. Defaults to "utt_id".
- hyp_key – Hypothesis key or keys expected in the output dict.
- ref_key – Reference key or keys expected in the output dict.
- **kwargs – Forwarded to BaseRunner.
static close_writers(writers: Dict[str, Any]) → Dict[str, Any] | None
Close shard-local SCP files and report which output keys were written.
static forward(idx, dataset=None, model=None, **kwargs)
Run inference for one or more dataset items and return output dict(s).
- Parameters:
- idx – Integer index or an iterable of integer indices into the dataset.
- dataset – Dataset providing inference entries.
- model – Inference model callable on the configured input.
- **kwargs – Expects input_key and optionally output_fn_path. model_kwargs may be used to pass extra keyword arguments through to the underlying model callable.
- Returns: Dict containing idx and output fields for a single item, or a list of dicts for batched inputs (as returned by output_fn).
- Raises:
- RuntimeError – If required input settings are missing.
- KeyError – If required input keys are missing from the dataset item(s).
- RuntimeError – If batched inference fails; includes guidance to disable batching when unsupported.
Notes
- input_key may be a string or a list/tuple of strings.
- Batched inputs are passed to the model as lists per key; padding is the model’s responsibility.
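The two notes above can be illustrated with a short sketch of how inputs might be gathered before the model call. It assumes dataset items are dicts keyed by field name; `collect_inputs` is a hypothetical helper and not part of the documented API, so the actual forward implementation may differ.

```python
from typing import Any, Dict, Iterable, Sequence, Union


def collect_inputs(
    dataset: Sequence[Dict[str, Any]],
    idx: Union[int, Iterable[int]],
    input_key: Union[str, Sequence[str]],
) -> Dict[str, Any]:
    """Hypothetical helper: gather model inputs for one index or a batch."""
    keys = [input_key] if isinstance(input_key, str) else list(input_key)
    batched = not isinstance(idx, int)
    indices = list(idx) if batched else [idx]
    items = [dataset[i] for i in indices]

    # Fail early when a requested input key is absent from an item.
    for i, item in zip(indices, items):
        missing = [k for k in keys if k not in item]
        if missing:
            raise KeyError(f"item {i} is missing input keys: {missing}")

    if batched:
        # One list per key; no padding is applied here.
        return {k: [item[k] for item in items] for k in keys}
    return {k: items[0][k] for k in keys}


toy_dataset = [
    {"speech": [0.1, 0.2], "text": "a"},
    {"speech": [0.3], "text": "b"},
]
single = collect_inputs(toy_dataset, 0, "speech")
batch = collect_inputs(toy_dataset, [0, 1], ["speech", "text"])
```

In the batched case the two `speech` entries keep their different lengths, which is why padding is left to the model.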
Example
>>> # Single-item inference
>>> out = InferenceRunner.forward(
...     0, dataset=dataset, model=model,
...     input_key="speech", output_fn_path="m.mod.out_fn"
... )
>>> # Batched inference
>>> out = InferenceRunner.forward(
...     [0, 1], dataset=dataset, model=model,
...     input_key=["speech", "text"], output_fn_path="m.mod.out_fn"
... )
merge(shard_dirs: List[Path]) → Dict[str, Any] | None
Merge per-shard SCP files into the test-set output directory.
Reads field_keys.txt from each shard to discover output field names, then concatenates each <field>.scp across shards in shard order into output_dir / shard_subdir.
- Parameters: shard_dirs – Completed shard directories in shard-id order.
- Returns: Empty dict on success (outputs are on disk).
- Return type: Dict[str, Any]
- Raises: RuntimeError – If no output keys are found across all shards.
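The merge step described above can be sketched with the standard library alone. This is a rough illustration, not the runner's implementation: it assumes field_keys.txt lists one field name per whitespace-separated token (the real file format is not specified here), and `merge_scp` is a hypothetical function name.

```python
import tempfile
from pathlib import Path
from typing import List


def merge_scp(shard_dirs: List[Path], out_dir: Path) -> List[str]:
    """Hypothetical sketch: concatenate <field>.scp files in shard order."""
    fields: List[str] = []
    for shard in shard_dirs:
        keys_file = shard / "field_keys.txt"
        if not keys_file.exists():
            continue
        # Assumed format: field names separated by whitespace/newlines.
        for name in keys_file.read_text(encoding="utf-8").split():
            if name not in fields:
                fields.append(name)
    if not fields:
        raise RuntimeError("no output keys found across shards")

    out_dir.mkdir(parents=True, exist_ok=True)
    for name in fields:
        with open(out_dir / f"{name}.scp", "w", encoding="utf-8") as out:
            for shard in shard_dirs:
                part = shard / f"{name}.scp"
                if part.exists():
                    out.write(part.read_text(encoding="utf-8"))
    return fields


# Toy run: two shards, one "hyp" field each.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    shards = []
    for shard_id, line in enumerate(["utt_a hello\n", "utt_b world\n"]):
        shard = root / f"shard_{shard_id}"
        shard.mkdir()
        (shard / "field_keys.txt").write_text("hyp\n", encoding="utf-8")
        (shard / "hyp.scp").write_text(line, encoding="utf-8")
        shards.append(shard)
    merged_fields = merge_scp(shards, root / "merged")
    merged_text = (root / "merged" / "hyp.scp").read_text(encoding="utf-8")
```

Because the shards are visited in the order given, the merged file preserves shard-id order as long as `shard_dirs` is passed in that order.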
static open_writers(shard_dir: Path | None, output_artifacts: Dict[str, dict] | None = None, **env) → Dict[str, Any]
Open per-shard SCP writers for worker-side inference outputs.
resolve_idx_key(output: Dict[str, Any]) → str
Validate that the configured sample-identifier key exists in output.
- Parameters: output – A single inference result dict.
- Returns: The idx_key attribute when present in output.
- Return type: str
- Raises: ValueError – If idx_key is not found in output.
static write_record(writers: Dict[str, Any], result: Any, state: Dict[str, Any], idx_key: str = 'utt_id', output_keys=None, hyp_key=None, ref_key=None, **env) → None
Validate one forward result and stream it into shard-local SCP files.
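As a rough picture of what streaming a result into SCP files could look like, the sketch below maps one result dict to per-file lines. It assumes each SCP line has the form `<id> <value>`, with the sample identifier in the first column as described for idx_key; the exact value formatting used by the runner is not documented here, and `scp_entries` is a hypothetical helper.

```python
from typing import Any, Dict, Sequence


def scp_entries(
    result: Dict[str, Any],
    idx_key: str = "utt_id",
    keys: Sequence[str] = ("hyp", "ref"),
) -> Dict[str, str]:
    """Hypothetical sketch: map one result to {scp_filename: line}."""
    uid = result[idx_key]
    entries: Dict[str, str] = {}
    for key in keys:
        value = result[key]
        if isinstance(value, (list, tuple)):
            # List-valued fields are split across numbered files.
            for n, item in enumerate(value):
                entries[f"{key}{n}.scp"] = f"{uid} {item}\n"
        else:
            entries[f"{key}.scp"] = f"{uid} {value}\n"
    return entries


entries = scp_entries({"utt_id": "utt_a", "hyp": ["hi", "high"], "ref": "hi"})
```

Here the list-valued `hyp` yields hyp0.scp and hyp1.scp, while the scalar `ref` yields a single ref.scp.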
