espnet3.components.data.collect_stats.collect_stats_batch
espnet3.components.data.collect_stats.collect_stats_batch(idxs: List[int], *, model=None, dataset=None, collate_fn=None, device: device | None = None, write_collected_feats: bool = False, collect_stats_kwargs: Dict[str, Any] | None = None)
Compute feature statistics for a batch of dataset items.
This helper:
- Materializes a list of dataset items for the given indices.
- Collates them into a mini-batch using `collate_fn`.
- Runs `model.collect_feats(...)` under `torch.no_grad()`.
- Accumulates per-feature sums, squared sums, and counts, and records per-utterance feature shapes for writing `*_shape` files.
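The accumulation step can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper name `accumulate_stats`, the `(batch, time, dim)` feature layout, and the use of a `lengths` array to skip padded frames are all assumptions made for the example.

```python
import numpy as np


def accumulate_stats(feats, lengths, uids):
    """Hypothetical helper: sum, squared sum, and frame count over valid frames."""
    dim = feats.shape[-1]
    total = np.zeros(dim, dtype=np.float64)
    total_sq = np.zeros(dim, dtype=np.float64)
    count = 0
    shapes = {}
    for uid, feat, length in zip(uids, feats, lengths):
        # Assumption: frames beyond `length` are padding and must be ignored.
        valid = feat[: int(length)].astype(np.float64)
        total += valid.sum(axis=0)
        total_sq += (valid ** 2).sum(axis=0)
        count += int(length)
        # Shape stored as a comma-separated string, as in the "shape_csv" values.
        shapes[uid] = ",".join(str(n) for n in valid.shape)
    return {"sum": total, "sq": total_sq, "count": count}, shapes


# Two utterances, padded to 2 frames of dimension 2; the second has 1 valid frame.
feats = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [0.0, 0.0]]])
lengths = np.array([2, 1])
stats, shapes = accumulate_stats(feats, lengths, ["utt1", "utt2"])
```

Keeping sums and squared sums rather than per-batch means lets results from many batches be added together before any normalization is computed.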
- Parameters:
  - idxs (List[int]) – Dataset indices to process as one batch.
  - model – Model instance that provides a callable `collect_feats` method.
  - dataset – Dataset providing `__getitem__`.
  - collate_fn – Collate function returning `(uids, batch_dict)`.
  - device (torch.device, optional) – Device to move batch tensors to.
  - write_collected_feats (bool) – If `True`, also returns raw collected feature arrays for writing (e.g., via `NpyScpWriter`).
  - collect_stats_kwargs (Dict[str, Any], optional) – Extra keyword arguments forwarded to `model.collect_feats` (must not collide with batch keys).
- Returns: `(stats, shape_info)` when `write_collected_feats=False`, otherwise `(stats, shape_info, feats)`, where:
  - `stats` maps feature-key -> `{"sum": ..., "sq": ..., "count": ...}`
  - `shape_info` maps feature-key -> `{uid: "shape_csv"}`
  - `feats` maps feature-key -> numpy array batch outputs
- Return type: Tuple[dict, dict] | Tuple[dict, dict, dict]
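To show how one `{"sum": ..., "sq": ..., "count": ...}` entry could be turned into normalization statistics, here is a small sketch using the standard identity Var[x] = E[x^2] - E[x]^2. The helper name `mean_and_std` and the `eps` floor are assumptions for illustration; how the library consumes these values downstream is not stated here.

```python
import numpy as np


def mean_and_std(entry, eps=1e-20):
    """Hypothetical helper: per-dimension mean and std from accumulated sums."""
    count = entry["count"]
    mean = entry["sum"] / count
    # Var[x] = E[x^2] - (E[x])^2, floored at eps to guard against round-off.
    var = entry["sq"] / count - mean ** 2
    std = np.sqrt(np.maximum(var, eps))
    return mean, std


entry = {"sum": np.array([9.0, 12.0]), "sq": np.array([35.0, 56.0]), "count": 3}
mean, std = mean_and_std(entry)
```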
- Raises:
  - RuntimeError – If `collate_fn` does not return `(uids, batch_dict)`.
  - ValueError – If `collect_stats_kwargs` overlaps with batch tensor keys.
  - RuntimeError – If
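The two documented error conditions can be pictured with the sketch below. It is only an illustration of the checks described above: the function names `unpack_collated` and `check_extra_kwargs`, the error messages, and the `use_cache` keyword are hypothetical and not taken from the library.

```python
def unpack_collated(collated):
    """Hypothetical check: collate_fn output must be (uids, batch_dict)."""
    if not (isinstance(collated, (tuple, list)) and len(collated) == 2):
        raise RuntimeError("collate_fn must return (uids, batch_dict)")
    uids, batch = collated
    if not isinstance(batch, dict):
        raise RuntimeError("collate_fn must return (uids, batch_dict)")
    return uids, batch


def check_extra_kwargs(batch, extra_kwargs):
    """Hypothetical check: extra kwargs must not shadow batch keys."""
    overlap = sorted(set(batch) & set(extra_kwargs or {}))
    if overlap:
        raise ValueError(f"collect_stats_kwargs overlaps with batch keys: {overlap}")


uids, batch = unpack_collated((["utt1"], {"speech": [0.1, 0.2]}))
check_extra_kwargs(batch, {"use_cache": False})  # no overlap, passes silently
```

Rejecting overlapping keys up front avoids an ambiguous call where the same argument name would be supplied both by the batch and by the extra keyword arguments.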
Example
>>> import torch
>>> from espnet3.components.data.collect_stats import collect_stats_batch
>>> stats, shapes = collect_stats_batch(
...     [0, 1, 2, 3],
...     model=model,
...     dataset=dataset,
...     collate_fn=collate_fn,
...     device=torch.device("cpu"),
... )
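Because each call covers a single batch, a caller would typically combine the results of several calls. The following sketch shows one way to do that; the helper name `merge_stats` is hypothetical, and it assumes only that the `sum`, `sq`, and `count` fields support `+` (plain numbers are used here, but NumPy arrays would work the same way).

```python
def merge_stats(total, batch_stats):
    """Hypothetical helper: add one batch's stats into a running total."""
    for key, entry in batch_stats.items():
        if key not in total:
            # Copy so later updates do not modify the caller's batch dict.
            total[key] = dict(entry)
        else:
            for field in ("sum", "sq", "count"):
                total[key][field] = total[key][field] + entry[field]
    return total


# Stand-ins for the `stats` dicts returned by two separate batch calls.
batch_a = {"feats": {"sum": 4.0, "sq": 10.0, "count": 2}}
batch_b = {"feats": {"sum": 5.0, "sq": 25.0, "count": 1}}

total = {}
for batch_stats in (batch_a, batch_b):
    merge_stats(total, batch_stats)
```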