espnet3.components.data.collect_stats.collect_stats_batch
espnet3.components.data.collect_stats.collect_stats_batch(idxs: List[int], model=None, dataset=None, collate_fn=None, device: device | None = None, write_collected_feats: bool = False, collect_stats_kwargs: Dict[str, Any] | None = None)
Collect feature statistics for one batch of dataset indices.
This function is the low-level batch worker used by collect_stats(). It reads dataset items, builds one collated batch, calls model.collect_feats(...), and accumulates per-feature sums, squared sums, counts, and shape metadata.
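The accumulation step can be pictured with the following minimal sketch. It is not the espnet3 implementation. It assumes that the features come back as a dict of padded NumPy arrays shaped (batch, time, dim), that valid lengths live under a hypothetical f"{key}_lengths" entry, and that shapes are recorded as comma-separated strings. The helper name accumulate_batch_stats is also hypothetical.

```python
from typing import Dict, List

import numpy as np


def accumulate_batch_stats(
    uids: List[str],
    feats: Dict[str, np.ndarray],
    stats: Dict[str, Dict[str, np.ndarray]],
    shape_info: Dict[str, Dict[str, str]],
) -> None:
    """Fold one batch of features into running sum / sq / count statistics.

    Hypothetical layout: each feature ``key`` is a padded array of shape
    (batch, time, dim); its valid lengths, if present, are stored under
    ``f"{key}_lengths"``.
    """
    for key, value in feats.items():
        if key.endswith("_lengths"):
            continue  # length arrays are metadata, not features
        lengths = feats.get(f"{key}_lengths")
        entry = stats.setdefault(key, {"sum": 0.0, "sq": 0.0, "count": 0})
        for i, uid in enumerate(uids):
            item = value[i]
            if lengths is not None:
                item = item[: int(lengths[i])]  # drop padded frames
            # Reduce over the time axis so statistics stay per feature dim.
            entry["sum"] = entry["sum"] + item.sum(axis=0)
            entry["sq"] = entry["sq"] + (item ** 2).sum(axis=0)
            entry["count"] += len(item)
            # Assumed shape-string format, e.g. "2,2" for 2 frames x 2 dims.
            shape_info.setdefault(key, {})[uid] = ",".join(map(str, item.shape))
```

Excluding padded frames keeps the sums from being biased toward zero. Keeping sq alongside sum lets variance be recovered later without storing every frame.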
- Parameters:
- idxs – Dataset indices to process as one batch.
- model – Model instance that provides a callable collect_feats method.
- dataset – Dataset or dataset-like object indexed by idxs.
- collate_fn – Collate function that returns (uids, batch_dict).
- device – Device used for tensor inputs passed to model.collect_feats.
- write_collected_feats – Whether to return the collected feature arrays in addition to the aggregated statistics.
- collect_stats_kwargs – Extra keyword arguments forwarded to model.collect_feats. Keys must not overlap with the collated batch tensor names.
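The no-overlap rule for collect_stats_kwargs can be illustrated with the sketch below. It assumes the batch entries and the extra kwargs are merged into a single keyword dict for model.collect_feats. The helper name build_collect_feats_kwargs is hypothetical, and moving tensors to device is left out so the example runs without PyTorch.

```python
from typing import Any, Dict, Optional


def build_collect_feats_kwargs(
    batch: Dict[str, Any],
    collect_stats_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Merge collated batch entries with extra kwargs, rejecting name clashes."""
    extra = dict(collect_stats_kwargs or {})
    # Any shared key would silently overwrite a batch tensor, so refuse it.
    overlap = sorted(set(batch) & set(extra))
    if overlap:
        raise ValueError(f"collect_stats_kwargs conflicts with batch keys: {overlap}")
    return {**batch, **extra}
```

Raising instead of letting one side win makes a misnamed option fail loudly, rather than quietly replacing an input tensor.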
- Returns:
(stats, shape_info) when write_collected_feats is False; (stats, shape_info, feats) when it is True. stats stores per-feature sum, sq, and count values. shape_info maps each feature key to a dict of uid -> shape string.
- Return type: tuple
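Per-feature sum, sq, and count values are enough to recover the mean and standard deviation, using Var[x] = E[x^2] - (E[x])^2. The sketch below shows that conversion. It assumes one stats entry is a dict with "sum", "sq", and "count" keys. The helper name stats_to_mean_std and the eps floor are illustrative choices, not part of the documented API.

```python
import numpy as np


def stats_to_mean_std(entry, eps: float = 1e-20):
    """Convert one feature's accumulated sum / sq / count into mean and std."""
    count = entry["count"]
    mean = np.asarray(entry["sum"], dtype=np.float64) / count
    # Var[x] = E[x^2] - (E[x])^2; clamp tiny negatives caused by round-off.
    var = np.asarray(entry["sq"], dtype=np.float64) / count - mean ** 2
    std = np.sqrt(np.maximum(var, eps))
    return mean, std
```

For example, sum=[7, 9], sq=[13, 23], count=5 gives mean=[1.4, 1.8] and std=[0.8, sqrt(1.36)].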
- Raises:
- RuntimeError – If collate_fn does not return (uids, batch_dict).
- ValueError – If collect_stats_kwargs conflicts with batch tensor names.
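The RuntimeError condition amounts to a shape check on the collate output. The minimal sketch below assumes the check accepts a 2-element tuple or list whose second element is a dict. The helper name unpack_collated is hypothetical, and the real function may check more or less strictly.

```python
from typing import Any, Dict, List, Tuple


def unpack_collated(output: Any) -> Tuple[List[str], Dict[str, Any]]:
    """Check that a collate_fn result has the (uids, batch_dict) shape."""
    if not (isinstance(output, (tuple, list)) and len(output) == 2):
        raise RuntimeError("collate_fn must return (uids, batch_dict)")
    uids, batch = output
    if not isinstance(batch, dict):
        raise RuntimeError("collate_fn must return (uids, batch_dict)")
    return list(uids), batch
```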
Notes
If the dataset exposes use_espnet_preprocessor=True, this function expects each dataset item to be (uid, sample).
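The (uid, sample) expectation can be sketched as follows. The check is keyed on the use_espnet_preprocessor attribute, and items pass through unchanged otherwise. The helper read_items, the ToyDataset class, and the choice of TypeError are all hypothetical. The note above states only the expected item format, not how violations are reported.

```python
from typing import Any, List, Sequence


def read_items(dataset: Any, idxs: Sequence[int]) -> List[Any]:
    """Fetch dataset items, requiring (uid, sample) pairs when the flag is set."""
    items = [dataset[i] for i in idxs]
    if getattr(dataset, "use_espnet_preprocessor", False):
        for i, item in zip(idxs, items):
            if not (isinstance(item, tuple) and len(item) == 2):
                # Hypothetical error type chosen for illustration only.
                raise TypeError(f"item {i} is not a (uid, sample) pair")
    return items


class ToyDataset:
    """Hypothetical dataset that opts into the (uid, sample) convention."""

    use_espnet_preprocessor = True

    def __init__(self, items):
        self.items = items

    def __getitem__(self, idx):
        return self.items[idx]
```

A plain list has no use_espnet_preprocessor attribute, so getattr falls back to False and its items are returned as-is.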
Examples
>>> stats, shape_info = collect_stats_batch(
...     [0, 1, 2, 3],
...     model=model,
...     dataset=dataset,
...     collate_fn=collate_fn,
...     device=torch.device("cpu"),
... )
