espnet3.components.data.collect_stats.CollectStatsInferenceProvider
class espnet3.components.data.collect_stats.CollectStatsInferenceProvider(model_config, dataset_config, dataloader_config, mode: str, task: str | None = None, shard_idx: int | None = None, params: Dict[str, Any] | None = {})
Bases: EnvironmentProvider
EnvironmentProvider tailored for collect-stats jobs.
- Parameters:
  - model_config – Configuration used to instantiate the model that exposes collect_feats.
  - dataset_config – Organizer configuration for the target dataset split.
  - dataloader_config – Dataloader configuration controlling collation and sharding.
  - mode (str) – Dataset split name (e.g., "train" or "valid").
  - task (str | None) – Optional ESPnet task name to resolve models.
  - shard_idx (int | None) – Optional shard index when running per-shard.
  - params (Dict[str, Any] | None) – Additional overrides merged into the environment (e.g., write_collected_feats).
Example
>>> provider = CollectStatsInferenceProvider(
... model_config=model_cfg,
... dataset_config=ds_cfg,
... dataloader_config=dl_cfg,
... mode="train",
... )
>>> env = provider.build_env_local()
>>> set(env.keys()) >= {"dataset", "model"}
True
Initialize CollectStatsInferenceProvider object.
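For sharded runs, the same provider can also be constructed with a shard index and extra overrides. A minimal sketch reusing the configs from the example above; the shard index and the write_collected_feats value are illustrative, not prescribed by this API:
>>> sharded_provider = CollectStatsInferenceProvider(
...     model_config=model_cfg,
...     dataset_config=ds_cfg,
...     dataloader_config=dl_cfg,
...     mode="valid",
...     shard_idx=0,
...     params={"write_collected_feats": True},
... )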
build_env_local() → Dict[str, Any]
Build the environment once on the driver for local inference.
build_worker_setup_fn()
Return a Dask worker setup function that builds the dataset and model on each worker.
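A minimal usage sketch, assuming a dask.distributed deployment; the Client construction below is an assumption about your cluster, not part of this API, and Client.run is just one way to invoke a setup function on every worker:
>>> from dask.distributed import Client
>>> setup_fn = provider.build_worker_setup_fn()
>>> client = Client()      # assumed: an already-configured Dask cluster
>>> client.run(setup_fn)   # build the dataset/model once on each worker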
