espnet3.systems.base.inference_provider.InferenceProvider
class espnet3.systems.base.inference_provider.InferenceProvider(config: DictConfig, *, params: Dict[str, Any] | None = None)
Bases: EnvironmentProvider, ABC
EnvironmentProvider specialized for dataset/model inference setup.
This implementation constructs only the dataset and the model and returns them as environment entries. It is suitable for both local and distributed execution.
Design:
- build_* helpers are defined as @staticmethod/classmethod so they are easily serializable and reusable on workers.
- The worker setup function must not capture self, to remain pickle-safe for Dask.
- Parameters:
  - config (DictConfig) – Hydra configuration used to build the dataset/model.
  - params (Dict[str, Any] | None) – Optional additional key-value pairs that will be merged into the returned environment (e.g., device, tokenizer, beam size).
Notes
- Subclasses must implement build_dataset and build_model (a minimal subclass sketch follows below).
- self.config.update(self.params) allows lightweight overrides (e.g., runtime overrides), but avoid mutating deep structures unless intended.
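For orientation, here is a minimal subclass sketch. MyDataset, MyModel, the my_project imports, and the config keys shown are hypothetical assumptions; actual construction depends on your project.

    from typing import Any

    from omegaconf import DictConfig

    from espnet3.systems.base.inference_provider import InferenceProvider


    class MyInferenceProvider(InferenceProvider):
        @staticmethod
        def build_dataset(config: DictConfig) -> Any:
            # Build purely from config; no reliance on self (keeps things pickle-safe).
            from my_project.data import MyDataset  # hypothetical import
            return MyDataset(path=config.dataset.path, split=config.dataset.split)

        @staticmethod
        def build_model(config: DictConfig) -> Any:
            # Load weights for inference only; see the build_model notes below.
            from my_project.models import MyModel  # hypothetical import
            model = MyModel.from_checkpoint(config.model.checkpoint)  # hypothetical loader
            model.eval()
            return model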
abstractmethod static build_dataset(config: DictConfig)
Construct and return the dataset instance.
Implemented by subclasses to build a dataset from config. During parallel or distributed execution, the config object passed here is the configuration that the user passed when instantiating the class.
- Parameters: config (DictConfig) – Configuration object for dataset parameters (e.g., data directory, preprocessing pipeline, features, split).
- Returns: Dataset object (type defined by subclass).
- Return type: Any
- Raises: NotImplementedError – Always in the base class; implement in a subclass.
Example
>>> # Minimal sketch; actual keys depend on your subclass
>>> from omegaconf import OmegaConf
>>> cfg = OmegaConf.create({
...     "dataset": {"path": "data/test", "split": "test"}
... })
>>> ds = MyInferenceProvider.build_dataset(cfg)
Notes
- Keep dataset initialization lightweight by using lazy loading or memory mapping when possible (see the sketch after this list).
- Rely on fields already present in config instead of reading global state whenever possible.
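As an illustration of the lazy-loading note above, the following hedged sketch defers file reads until the dataset is first indexed; LazyManifestDataset and the config.dataset.manifest key are assumptions, not part of the base API.

    class LazyManifestDataset:
        """Illustrative lazy dataset: the manifest is read only on first access."""

        def __init__(self, manifest_path: str):
            self.manifest_path = manifest_path
            self._entries = None  # populated lazily

        def _ensure_loaded(self):
            if self._entries is None:
                with open(self.manifest_path) as f:
                    self._entries = [line.strip() for line in f if line.strip()]

        def __len__(self):
            self._ensure_loaded()
            return len(self._entries)

        def __getitem__(self, idx):
            self._ensure_loaded()
            return self._entries[idx]


    # Inside a subclass, build_dataset would only wire config fields into the dataset:
    #     @staticmethod
    #     def build_dataset(config):
    #         return LazyManifestDataset(config.dataset.manifest)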
build_env_local() → Dict[str, Any]
Build the environment once on the driver for local inference.
- Returns: Environment dict with at least:
  - "dataset": The instantiated dataset.
  - "model": The instantiated model.
  Any additional fields from params are also included.
- Return type: Dict[str, Any]
Example
>>> provider = MyInferenceProvider(cfg, params={"device": "cuda"})
>>> env = provider.build_env_local()
>>> env.keys()
dict_keys(['dataset', 'model', 'device'])
abstractmethod static build_model(config: DictConfig)
Construct and return the model instance.
Implemented by subclasses to build a model from config. During parallel or distributed execution, the config object passed here is the configuration that the user passed when instantiating the class.
- Parameters: config (DictConfig) – Configuration object used to build the model (e.g., checkpoint path, device).
- Returns: Model object (type defined by subclass).
- Return type: Any
- Raises: NotImplementedError – Always in the base class; implement in a subclass.
Example
>>> # Minimal sketch; actual keys depend on your subclass
>>> from omegaconf import OmegaConf
>>> cfg = OmegaConf.create({
...     "model": {"checkpoint": "exp/model.pth", "device": "cpu"}
... })
>>> model = MyInferenceProvider.build_model(cfg)
Notes
- This method should handle loading weights and placing the model on the appropriate device (see the sketch after this list).
- Do not perform training or optimization here; this is for inference setup only.
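A hedged sketch of a PyTorch-based build_model is given below; MyNet, the my_project import, and the config keys (model.checkpoint, model.device, model.kwargs) are assumptions, not part of the base API.

    import torch
    from omegaconf import DictConfig

    from my_project.models import MyNet  # hypothetical model definition


    def build_model(config: DictConfig):  # would be a @staticmethod on the subclass
        device = torch.device(config.model.get("device", "cpu"))
        model = MyNet(**config.model.get("kwargs", {}))
        state = torch.load(config.model.checkpoint, map_location=device)
        model.load_state_dict(state)
        model.to(device)
        model.eval()  # inference only: no optimizer or gradient setup here
        return model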
make_worker_setup_fn() → Callable[[], Dict[str, Any]]
Return a Dask worker setup function that builds dataset/model.
The returned function is executed once per worker process and must not capture self. It closes over the immutable config and params snapshot, then constructs the environment on each worker.
- Returns: A zero-argument setup function that returns {"dataset": ..., "model": ..., **params}.
- Return type: Callable[[], Dict[str, Any]]
Example
>>> provider = MyInferenceProvider(cfg, params={"device": "cuda:0"})
>>> setup_fn = provider.make_worker_setup_fn()
>>> env = setup_fn()
>>> "dataset" in env and "model" in env
True
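For intuition only, the sketch below shows one way such a pickle-safe setup function can be assembled; make_setup_fn is a hypothetical stand-in, not the actual implementation of make_worker_setup_fn. The closure captures only plain data (a config copy and a params dict), never a provider instance.

    import copy
    from typing import Any, Callable, Dict

    from omegaconf import DictConfig


    def make_setup_fn(
        cls,  # the InferenceProvider subclass, e.g. MyInferenceProvider
        config: DictConfig,
        params: Dict[str, Any] | None,
    ) -> Callable[[], Dict[str, Any]]:
        # Snapshot plain data up front; the closure never touches self.
        config_snapshot = copy.deepcopy(config)
        params_snapshot = dict(params or {})

        def _setup() -> Dict[str, Any]:
            # Runs once per Dask worker and rebuilds the environment locally.
            return {
                "dataset": cls.build_dataset(config_snapshot),
                "model": cls.build_model(config_snapshot),
                **params_snapshot,
            }

        return _setup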