espnet3.components.data.collect_stats.CollectStatsRunner
class espnet3.components.data.collect_stats.CollectStatsRunner(provider: EnvironmentProvider, output_dir: str | Path, mode: str, write_collected_feats: bool = False, **kwargs)
Bases: BaseRunner
Execute collect-stats batches and merge shard outputs.
The runner delegates batch execution to collect_stats_batch(), persists per-shard metadata and optional collected features, and merges the shard outputs into the final mode directory.
Initialize the runner.
- Parameters:
- provider – Environment provider that builds the dataset and model.
- output_dir – Root directory for collect-stats outputs.
- mode – Dataset split name used as the shard subdirectory.
- write_collected_feats – Whether to persist raw collected features.
- **kwargs – Extra arguments forwarded to BaseRunner.
static close_writers(writers: Dict[str, Any]) → Dict[str, Any] | None
Close shard writers and flush shard summary files.
Writes shape_keys.txt, feat_keys_written.txt, stats_keys.txt, and {key}_stats.npz for every accumulated feature key to the shard directory.
- Parameters: writers – Writer state produced by open_writers() and populated by write_record().
- Returns: None
static forward(batch_indices: Iterable[int] | int, dataset, model, collate_fn, device, write_collected_feats: bool = False, collect_stats_kwargs: Dict[str, Any] | None = None, **env)
Run collect-stats for one batch index group.
- Parameters:
- batch_indices – One or more dataset indices forming a single batch.
- dataset – Dataset indexed by the provided indices.
- model – Model with a callable collect_feats method.
- collate_fn – Collate function returning (uids, batch_dict).
- device – Device to move input tensors onto before inference.
- write_collected_feats – Whether to include raw feature arrays in the return value.
- collect_stats_kwargs – Extra keyword arguments forwarded to model.collect_feats.
- **env – Additional environment keys (ignored).
- Returns: (stats, shape_info) or (stats, shape_info, feats) as returned by collect_stats_batch().
- Return type: tuple
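The batch_indices argument accepts either a single index or an iterable of indices, and collate_fn is documented as returning (uids, batch_dict). The following is a minimal sketch of that input handling only. The helper as_index_list, the toy_collate function, and the (uid, sample) item layout of the toy dataset are all hypothetical and are not taken from the ESPnet source.

```python
from typing import Any, Dict, Iterable, List, Tuple, Union


def as_index_list(batch_indices: Union[Iterable[int], int]) -> List[int]:
    """Normalize a single index or an iterable of indices to a list."""
    if isinstance(batch_indices, int):
        return [batch_indices]
    return list(batch_indices)


def toy_collate(
    items: List[Tuple[str, Dict[str, Any]]],
) -> Tuple[List[str], Dict[str, Any]]:
    """Hypothetical collate_fn that returns (uids, batch_dict)."""
    uids = [uid for uid, _ in items]
    batch = {"speech": [sample["speech"] for _, sample in items]}
    return uids, batch


# Toy dataset (assumption): each item is a (uid, sample_dict) pair.
dataset = [
    ("utt0", {"speech": [0.1, 0.2]}),
    ("utt1", {"speech": [0.3]}),
    ("utt2", {"speech": [0.4, 0.5, 0.6]}),
]

# Gather the requested items and collate them into one batch.
indices = as_index_list([0, 2])
uids, batch = toy_collate([dataset[i] for i in indices])
```

In the real runner the collated batch would then be moved to device and passed on for feature collection; that part is omitted here because it depends on the model.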
merge(shard_dirs: List[Path]) → Dict[str, Any]
Merge shard outputs into aggregated statistics for one split.
- Parameters: shard_dirs – List of shard directories produced by close_writers(), one per parallel shard.
- Returns: Aggregated totals with keys "sum", "sq", and "count", each mapping feature key to its accumulated value. Shape files and optional collected-feat SCP files are also concatenated into output_dir / mode.
- Return type: dict
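To illustrate the shape of the aggregated totals, here is a minimal sketch of combining per-shard "sum", "sq", and "count" dictionaries and deriving a mean and variance from them. It assumes "sq" holds a sum of squares and uses scalar values for brevity, whereas real statistics would typically be per-dimension arrays. The functions merge_totals and mean_and_variance are hypothetical names and not part of the ESPnet API.

```python
from typing import Dict, List, Tuple

Totals = Dict[str, Dict[str, float]]


def merge_totals(shards: List[Totals]) -> Totals:
    """Add up "sum", "sq", and "count" across shards, per feature key."""
    merged: Totals = {"sum": {}, "sq": {}, "count": {}}
    for shard in shards:
        for field in ("sum", "sq", "count"):
            for key, value in shard[field].items():
                merged[field][key] = merged[field].get(key, 0.0) + value
    return merged


def mean_and_variance(totals: Totals, key: str) -> Tuple[float, float]:
    """Derive mean and (population) variance, assuming "sq" is a sum of squares."""
    n = totals["count"][key]
    mean = totals["sum"][key] / n
    var = totals["sq"][key] / n - mean ** 2
    return mean, max(var, 0.0)  # clamp tiny negative values from rounding


# Shard A saw the values 1, 2, 3; shard B saw the values 4, 5.
shard_a: Totals = {"sum": {"feats": 6.0}, "sq": {"feats": 14.0}, "count": {"feats": 3}}
shard_b: Totals = {"sum": {"feats": 9.0}, "sq": {"feats": 41.0}, "count": {"feats": 2}}

merged = merge_totals([shard_a, shard_b])
mean, var = mean_and_variance(merged, "feats")
```

Because sums, sums of squares, and counts are all additive, shards can be merged in any order and still give the same result as a single pass over the data.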
static open_writers(shard_dir: Path | None, write_collected_feats: bool = False, **env) → Dict[str, Any]
Create per-shard writer state.
- Parameters:
- shard_dir – Directory where shard output files are written.
- write_collected_feats – Whether to open feature writers in addition to shape-file handles.
- **env – Additional environment keys (ignored).
- Returns: Initial writers state with keys _shard_dir, _write_feats, shape_handles, and feat_writers.
- Return type: dict
static write_record(writers: Dict[str, Any], result, state: Dict[str, Any], **env) → None
Accumulate one batch result into shard files and in-memory state.
- Parameters:
- writers – Writer state returned by open_writers(). Updated in-place with running sums, shape file handles, and optional feature writers.
- result – Return value of forward(): either (stats, shape_info) or (stats, shape_info, feats).
- state – Persistent in-memory accumulator for sum, sq, and count across batches.
- **env – Additional environment keys (ignored).
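The accumulation step can be sketched roughly as follows. This is an illustration of handling the two documented result shapes and folding them into a running state, not the actual implementation. The function accumulate and the assumed layout of stats (feature key mapped to a dict with "sum", "sq", and "count") are hypothetical, and file writing is left out entirely.

```python
from typing import Any, Dict, Optional, Tuple


def accumulate(state: Dict[str, Dict[str, float]], result: Tuple) -> Optional[Any]:
    """Fold one batch result into state; return feats if present, else None."""
    # result is (stats, shape_info) or (stats, shape_info, feats).
    stats, shape_info, *rest = result
    feats = rest[0] if rest else None

    # Assumed layout: stats[key] == {"sum": ..., "sq": ..., "count": ...}.
    for key, entry in stats.items():
        for field in ("sum", "sq", "count"):
            bucket = state.setdefault(field, {})
            bucket[key] = bucket.get(key, 0.0) + entry[field]
    return feats


state: Dict[str, Dict[str, float]] = {}

# A two-element result (no raw features).
first = accumulate(state, ({"feats": {"sum": 3.0, "sq": 5.0, "count": 2}}, {}))

# A three-element result (raw features included).
second = accumulate(
    state,
    ({"feats": {"sum": 1.0, "sq": 1.0, "count": 1}}, {}, {"feats": [[1.0]]}),
)
```

Star-unpacking the tail of the tuple keeps a single code path for both result shapes, so the caller only needs to check whether the returned feats value is None.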
