ESPnet3 Inference Provider
InferenceProvider is the convenience provider for inference-time environment construction.
ESPnet3 currently has two related layers:

- `espnet3.parallel.inference_provider.InferenceProvider`, the lower-level parallel provider base class
- `espnet3.systems.base.inference_provider.InferenceProvider`, the stage-facing provider

The stage-facing class used by `inference.yaml` is the second one, and it is the main focus of this page.
It is a small provider with a clear job:
- build the dataset for one test set
- build the model
- return them as a plain env dictionary for the runner
This is the provider used by the base inference stage.
What it returns
The provider returns an env like this:
```python
{
    "dataset": dataset,
    "model": model,
    **params,
}
```

`dataset` and `model` are the core entries.
Any extra params are merged into the env too. That makes it easy to inject small runtime values without changing the runner.
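A minimal sketch of how such an env is assembled, using a plain function with placeholder values instead of the real provider (the name `build_env` is hypothetical). Because `**params` is unpacked after the core entries, a params key named `dataset` or `model` would silently replace the built object, which is one more reason to keep params to small, distinct flags.

```python
from typing import Any, Dict


def build_env(dataset: Any, model: Any, params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the env layout shown above."""
    # Core entries first; params are unpacked last, so on a key clash
    # the params value wins.
    return {
        "dataset": dataset,
        "model": model,
        **params,
    }


env = build_env(dataset=["utt1", "utt2"], model="toy-model", params={"beam_size": 8})
print(sorted(env))  # ['beam_size', 'dataset', 'model']
```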
Where it fits
The flow is:

- `infer()` resolves one `test_set`
- `InferenceProvider` builds the dataset and model for that set
- `InferenceRunner.forward(...)` receives the env by keyword name
- outputs are materialized under `inference_dir`
So this provider connects three things:

- `inference.yaml`
- the instantiated dataset/model objects
- the runner that executes inference over sample indices
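The flow above can be sketched as a small driver loop. Everything here is a hypothetical stand-in (`run_inference`, the toy splits, `toy_build_env`, `toy_forward`); it rebuilds the env for each test set and skips writing outputs to `inference_dir`, so it only illustrates the order of the steps, not the real stage.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict, List


def run_inference(
    config: SimpleNamespace,
    test_sets: List[str],
    build_env: Callable[[SimpleNamespace], Dict[str, Any]],
    forward: Callable[..., Any],
) -> Dict[str, List[Any]]:
    """Hypothetical driver loop: one env per test set, one forward call per index."""
    outputs: Dict[str, List[Any]] = {}
    for test_set in test_sets:
        config.test_set = test_set   # infer() picks one test set
        env = build_env(config)      # provider builds dataset + model for it
        dataset = env["dataset"]
        # The runner receives every env entry as a keyword argument.
        outputs[test_set] = [forward(idx, **env) for idx in range(len(dataset))]
    return outputs


# Toy stand-ins for the organizer, the provider, and the runner.
splits = {"test_clean": ["a", "b"], "test_other": ["c"]}


def toy_build_env(config: SimpleNamespace) -> Dict[str, Any]:
    return {"dataset": splits[config.test_set], "model": str.upper}


def toy_forward(idx, dataset, model, **env):
    return model(dataset[idx])


print(run_inference(SimpleNamespace(), ["test_clean", "test_other"], toy_build_env, toy_forward))
# {'test_clean': ['A', 'B'], 'test_other': ['C']}
```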
Base behavior
The base class already implements dataset construction, model construction, and env assembly for both local and worker execution. In many cases, you only need to subclass it if your dataset or model setup is special.
build_dataset(config)
See the build_dataset API docs for the exact contract.
The default implementation:
- instantiates `config.dataset`
- expects that object to be a `DataOrganizer`
- reads `config.test_set`
- returns `organizer.test[test_set]`
That means the provider is designed for one concrete test split at a time.
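```python
def build_dataset(config):
    organizer = instantiate(config.dataset)  # expected to be a DataOrganizer
    return organizer.test[config.test_set]   # one concrete test split
```

build_model(config)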
See the build_model API docs for the exact contract.
The default implementation:
- resolves the visible device
- instantiates `config.model`
- passes `device=...` into the model constructor
The device resolution logic checks, in this order:

- `config.device`
- `config.device_index`
- `config.local_rank`
- the `LOCAL_RANK` environment variable
- falls back to `cuda:0` or `cpu`
This is important for local_gpu or worker-per-GPU execution.
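A minimal sketch of that lookup order, not the ESPnet3 source. It assumes that the index fields map to `cuda:<index>` and that CUDA availability is passed in as a flag (in a real process that would typically come from `torch.cuda.is_available()`); `resolve_device` and its `environ` parameter are hypothetical. Note that `cuda:N` always names the N-th *visible* device, not a physical GPU id.

```python
import os
from types import SimpleNamespace
from typing import Mapping, Optional


def resolve_device(
    config: SimpleNamespace,
    cuda_available: bool,
    environ: Optional[Mapping[str, str]] = None,
) -> str:
    """Sketch of the documented lookup order (hypothetical helper)."""
    environ = os.environ if environ is None else environ

    # 1. An explicit device string wins.
    device = getattr(config, "device", None)
    if device is not None:
        return str(device)

    # 2-3. Index fields from the config (0 is valid, so compare with None).
    for key in ("device_index", "local_rank"):
        index = getattr(config, key, None)
        if index is not None:
            return f"cuda:{int(index)}"

    # 4. LOCAL_RANK exported by the launcher.
    if "LOCAL_RANK" in environ:
        return f"cuda:{int(environ['LOCAL_RANK'])}"

    # 5. Fallback.
    return "cuda:0" if cuda_available else "cpu"


print(resolve_device(SimpleNamespace(local_rank=1), cuda_available=True, environ={}))  # cuda:1
print(resolve_device(SimpleNamespace(), cuda_available=False, environ={}))  # cpu
```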
Example: use the default behavior
If your `inference.yaml` already defines `dataset`, `model`, `provider`, and `runner`, you can usually point `provider` at the base provider directly:
```yaml
provider:
  _target_: espnet3.systems.base.inference_provider.InferenceProvider

runner:
  _target_: espnet3.systems.base.inference_runner.InferenceRunner

dataset:
  _target_: espnet3.components.data.data_organizer.DataOrganizer
  test:
    - name: test_clean
      data_src: egs3.mini_an4.asr.dataset.builder
      data_src_args:
        split: test

model:
  _target_: egs3.mini_an4.asr.src.inference.SimpleInferenceModel
```

In that setup:
- `infer()` sets `config.test_set = "test_clean"`
- the provider returns `organizer.test["test_clean"]`
- the model is instantiated once per process
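The split selection in that setup can be sketched with a hypothetical stand-in for `DataOrganizer` that only has a `test` mapping. The sketch receives the organizer directly instead of instantiating it from `config.dataset`, and the explicit `test_set` check is an addition for illustration that surfaces a missing `config.test_set` early.

```python
from types import SimpleNamespace


class ToyOrganizer:
    """Hypothetical stand-in for DataOrganizer: only the `.test` mapping matters here."""

    def __init__(self, test):
        self.test = test


def build_dataset(config, organizer):
    # Added check (not from ESPnet3): fail clearly if no test set was chosen.
    if not hasattr(config, "test_set"):
        raise ValueError("config.test_set must be set before build_dataset() runs")
    # Same selection as the default: exactly one split, never the whole organizer.
    return organizer.test[config.test_set]


organizer = ToyOrganizer(test={"test_clean": ["utt-1", "utt-2"]})
config = SimpleNamespace(test_set="test_clean")
print(build_dataset(config, organizer))  # ['utt-1', 'utt-2']
```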
Example: override dataset construction
Override `build_dataset()` when:

- your test dataset is not under `organizer.test`
- you need extra filtering
- you need a wrapper dataset for inference-only normalization
```python
from hydra.utils import instantiate

from espnet3.systems.base.inference_provider import InferenceProvider


class MyInferenceProvider(InferenceProvider):
    @staticmethod
    def build_dataset(config):
        # Same split selection as the base class.
        organizer = instantiate(config.dataset)
        dataset = organizer.test[config.test_set]
        # MyDatasetWrapper is your own class; it is not part of ESPnet3.
        return MyDatasetWrapper(dataset, normalize_text=True)
```

This keeps the rest of the provider behavior unchanged.
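`MyDatasetWrapper` in the example is something you write yourself. One possible shape is sketched below, assuming each sample is a dict with a `"text"` field and that "normalization" means lower-casing and collapsing whitespace; both are assumptions for illustration.

```python
from typing import Any, Dict, Sequence


class MyDatasetWrapper:
    """Hypothetical inference-only wrapper that normalizes the reference text."""

    def __init__(self, dataset: Sequence[Dict[str, Any]], normalize_text: bool = False):
        self.dataset = dataset
        self.normalize_text = normalize_text

    def __len__(self) -> int:
        # The runner iterates over sample indices, so keep len() intact.
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        sample = dict(self.dataset[idx])  # shallow copy so the wrapped data is untouched
        if self.normalize_text and "text" in sample:
            sample["text"] = " ".join(sample["text"].lower().split())
        return sample


wrapped = MyDatasetWrapper([{"text": "  HELLO   World "}], normalize_text=True)
print(wrapped[0]["text"])  # hello world
```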
Example: override model construction
Override build_model() when:
- the model needs custom preload logic
- you need to load auxiliary artifacts
- you want to attach runtime helpers before inference starts
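```python
from hydra.utils import instantiate

from espnet3.systems.base.inference_provider import InferenceProvider


class MyInferenceProvider(InferenceProvider):
    @staticmethod
    def build_model(config):
        # The device is chosen explicitly here instead of by the base resolution logic.
        model = instantiate(config.model, device="cpu")
        # load_extra_assets() and config.extra_assets belong to your own model/config.
        model.load_extra_assets(config.extra_assets)
        model.eval()
        return model
```

If you override this method, keep device handling explicit. Do not derive physical GPU ids from `CUDA_VISIBLE_DEVICES`; inside the process, device indices are relative to the visible devices.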
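The reason not to hard-code physical ids: when `CUDA_VISIBLE_DEVICES` is set, CUDA renumbers the visible GPUs from zero, so `cuda:0` inside the process is the first listed GPU, not physical GPU 0. The hypothetical helper below illustrates that mapping for the plain comma-separated integer form of the variable only (UUID entries and other forms are not handled).

```python
import os
from typing import Mapping, Optional


def physical_gpu_id(logical_index: int, environ: Optional[Mapping[str, str]] = None) -> int:
    """Map a logical index (as in `cuda:<index>`) to the physical GPU id it refers to."""
    environ = os.environ if environ is None else environ
    visible = environ.get("CUDA_VISIBLE_DEVICES")
    if visible is None:
        # No masking: logical and physical numbering coincide.
        return logical_index
    ids = [int(part) for part in visible.split(",") if part.strip()]
    return ids[logical_index]


print(physical_gpu_id(0, {"CUDA_VISIBLE_DEVICES": "2,3"}))  # 2
print(physical_gpu_id(1, {"CUDA_VISIBLE_DEVICES": "2,3"}))  # 3
```

So a process that sees `CUDA_VISIBLE_DEVICES=2,3` should ask for `cuda:1` to use physical GPU 3, never `cuda:3`.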
Example: use params for small runtime values
params are merged into the env after dataset/model construction.
That is useful for small values that the runner needs by name.
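```python
from espnet3.systems.base.inference_provider import InferenceProvider

provider = InferenceProvider(
    inference_config=cfg,  # the loaded inference.yaml config
    params={
        "beam_size": 8,
        "return_attention": False,
    },
)
```

Then the runner can receive them directly:

```python
from espnet3.systems.base.inference_runner import InferenceRunner


class MyInferenceRunner(InferenceRunner):
    @staticmethod
    def forward(idx, dataset, model, beam_size, return_attention, **env):
        # dataset, model, and the params keys all arrive as keyword arguments.
        sample = dataset[idx]
        return model(sample, beam_size=beam_size, return_attention=return_attention)
```

Use `params` for lightweight runtime flags. Do not use them to smuggle large driver-side objects into workers.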
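How the keyword matching plays out, shown with a hypothetical `ToyRunner` and `toy_model` in place of the real classes: every env key becomes a keyword argument, named parameters bind to matching keys, leftover keys land in `**env`, and a missing required name such as `beam_size` raises `TypeError` before any sample is processed.

```python
from typing import Any, Dict, List


class ToyRunner:
    """Hypothetical runner with the same forward signature as the example above."""

    @staticmethod
    def forward(idx, dataset, model, beam_size, return_attention, **env):
        sample = dataset[idx]
        return model(sample, beam_size=beam_size, return_attention=return_attention)


def toy_model(sample: str, beam_size: int, return_attention: bool) -> str:
    return f"{sample}|beam={beam_size}|attn={return_attention}"


env: Dict[str, Any] = {
    "dataset": ["utt-a", "utt-b"],
    "model": toy_model,
    # Entries merged from params:
    "beam_size": 8,
    "return_attention": False,
    "unused_flag": True,  # absorbed by **env without error
}

results: List[str] = [ToyRunner.forward(i, **env) for i in range(len(env["dataset"]))]
print(results[0])  # utt-a|beam=8|attn=False
```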
Local vs worker behavior
In local mode:

- `build_env_local()` runs once on the driver

In Dask mode:

- `build_worker_setup_fn()` returns a zero-argument setup function
- that setup function runs once per worker
- each worker builds its own dataset/model pair
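This is why the base implementation avoids capturing `self` in the worker setup closure: the setup function has to be serialized and sent to each worker, and capturing `self` would drag the whole provider object, including any driver-side state it holds, along with it.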
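A minimal sketch of that pattern, using a hypothetical `ToyProvider` rather than the real class. The setup function closes over the class, the config, and a copy of the params only, so serializing it (Dask relies on cloudpickle for functions like this) does not pull in `self` or its driver-side state. Calling the hooks through `cls` keeps subclass overrides of `build_dataset()` and `build_model()` in effect.

```python
from types import SimpleNamespace
from typing import Any, Callable, Dict


class ToyProvider:
    """Hypothetical provider showing the self-free worker setup pattern."""

    def __init__(self, config, params=None):
        self.config = config
        self.params = dict(params or {})
        self.driver_only_state = [0] * 1_000_000  # stands in for heavy driver-side data

    @staticmethod
    def build_dataset(config):
        return list(range(config.num_samples))

    @staticmethod
    def build_model(config):
        return SimpleNamespace(name=config.model_name)

    def build_worker_setup_fn(self) -> Callable[[], Dict[str, Any]]:
        # Copy what workers need into locals; the closure refers only to these.
        cls, config, params = type(self), self.config, dict(self.params)

        def setup() -> Dict[str, Any]:
            # Runs on each worker: build a fresh dataset/model pair there.
            return {
                "dataset": cls.build_dataset(config),
                "model": cls.build_model(config),
                **params,
            }

        return setup


provider = ToyProvider(SimpleNamespace(num_samples=3, model_name="toy"), {"beam_size": 4})
setup_fn = provider.build_worker_setup_fn()
captured = [cell.cell_contents for cell in setup_fn.__closure__]
print(any(obj is provider for obj in captured))  # False
```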
When to subclass
Use the base class as-is when:

- `config.dataset` is a normal `DataOrganizer`
- `config.model` is directly instantiable
- one `test_set` maps cleanly to one dataset
Subclass it when:
- dataset selection is custom
- model loading needs extra steps
- worker-local setup needs extra artifacts
Common mistakes
- Forgetting that `build_dataset()` expects `config.test_set`
- Returning the full organizer instead of one dataset split
- Putting large objects into `params`
- Hard-coding a physical GPU id in `build_model()`
- Reimplementing `build_env_local()` when overriding `build_dataset()` or `build_model()` would be enough
Related pages
- Provider / Runner: read the base contract for providers, runners, and env injection.
- Parallel Config: configure `local`, `local_gpu`, SSH, or HPC backends.
- Inference Config: see how `provider`, `runner`, `dataset`, and `model` are written in YAML.
- Inference Stage: see how the provider is used inside the stage entrypoint.
- Systems InferenceProvider API: read the generated docstring contract for the stage-facing provider.
- InferenceRunner API: read how env keys are consumed by the base inference runner.
- Parallel InferenceProvider API: compare the lower-level parallel provider base class.
