espnet2.main_funcs.collect_stats.collect_stats
espnet2.main_funcs.collect_stats.collect_stats(model: AbsESPnetModel | None, train_iter: Iterable[Tuple[List[str], Dict[str, Tensor]]], valid_iter: Iterable[Tuple[List[str], Dict[str, Tensor]]], output_dir: Path, ngpu: int | None, log_interval: int | None, write_collected_feats: bool) → None
Run the collect_stats mode.
This function derives shape information from the data and gathers statistics for each feature key. It is run before training and processes both the training and validation datasets.
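The description above does not say which statistics are gathered. A common choice for feature normalization is to accumulate a count, a sum, and a sum of squares per feature key, from which the mean and standard deviation follow. The sketch below illustrates that idea in plain Python with scalar features. It is not ESPnet's implementation, and `accumulate_stats` and `mean_and_std` are hypothetical names.

```python
from collections import defaultdict
from math import sqrt


def accumulate_stats(batches):
    """Accumulate per-key count, sum, and sum of squares (hypothetical helper)."""
    count = defaultdict(int)
    total = defaultdict(float)
    total_sq = defaultdict(float)
    for batch in batches:
        for key, values in batch.items():
            for v in values:
                count[key] += 1
                total[key] += v
                total_sq[key] += v * v
    return count, total, total_sq


def mean_and_std(count, total, total_sq):
    """Turn the accumulated sums for one key into a mean and standard deviation."""
    mean = total / count
    # Clamp at zero to guard against small negative values from rounding.
    var = max(total_sq / count - mean * mean, 0.0)
    return mean, sqrt(var)


# Two toy batches, each mapping a feature key to a list of scalar values.
batches = [{"feats": [1.0, 2.0]}, {"feats": [3.0, 4.0]}]
count, total, total_sq = accumulate_stats(batches)
mean, std = mean_and_std(count["feats"], total["feats"], total_sq["feats"])
```

Accumulating sums rather than storing every value keeps memory use constant across the whole pass, which matters when the dataset does not fit in memory.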
espnet2.main_funcs.collect_stats.model
The ESPnet model to use for feature extraction. If None, feature extraction is skipped.
- Type: Union[AbsESPnetModel, None]
espnet2.main_funcs.collect_stats.train_iter
The training data iterator providing batches of data.
- Type: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]]
espnet2.main_funcs.collect_stats.valid_iter
The validation data iterator providing batches of data.
- Type: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]]
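Both iterators yield a list of string keys together with a dictionary of tensors. The sketch below shows one way shape information could be derived from such a batch. It makes several assumptions that the text here does not confirm: nested Python lists stand in for padded tensors, a `<name>_lengths` entry gives the unpadded length of each sequence, and the shape is rendered as a comma-separated string. `batch_shapes` is a hypothetical helper.

```python
def batch_shapes(keys, batch):
    """Hypothetical helper: map each feature name to {key: shape string}.

    Nested lists stand in for padded tensors of shape (batch, time, dim).
    """
    shapes = {}
    for name, padded in batch.items():
        if name.endswith("_lengths"):
            continue  # length entries describe other features
        lengths = batch.get(f"{name}_lengths")
        shapes[name] = {}
        for i, (key, seq) in enumerate(zip(keys, padded)):
            if lengths is not None:
                seq = seq[: lengths[i]]  # drop the padded frames
            dim = len(seq[0]) if seq else 0
            shapes[name][key] = f"{len(seq)},{dim}"
    return shapes


keys = ["utt1", "utt2"]
batch = {
    "speech": [
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
        [[0.7, 0.8], [0.0, 0.0], [0.0, 0.0]],  # padded to 3 frames
    ],
    "speech_lengths": [3, 1],
}
shapes = batch_shapes(keys, batch)
```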
espnet2.main_funcs.collect_stats.output_dir
The directory where collected statistics will be saved.
- Type: Path
espnet2.main_funcs.collect_stats.ngpu
The number of GPUs to use for computation. If None, CPU will be used.
- Type: Optional[int]
espnet2.main_funcs.collect_stats.log_interval
Interval for logging progress. If None, it will be automatically set based on the number of iterations.
- Type: Optional[int]
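The text above says only that a missing log_interval is derived from the number of iterations. One plausible way to do this is sketched below. The divisor, the floor, and the fallback value are illustrative assumptions, and `default_log_interval` is a hypothetical helper rather than part of the ESPnet API.

```python
def default_log_interval(data_iter, fallback=100):
    """Hypothetical helper: choose a logging interval from the iterator length."""
    try:
        # About one twentieth of the pass, with a floor of 10 iterations.
        return max(len(data_iter) // 20, 10)
    except TypeError:
        # Generators and other unsized iterables do not support len().
        return fallback


sized = list(range(1000))
short = list(range(40))
unsized = (i for i in range(1000))

interval_sized = default_log_interval(sized)
interval_short = default_log_interval(short)
interval_unsized = default_log_interval(unsized)
```

The fallback covers iterators whose length is unknown in advance, such as streaming datasets.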
espnet2.main_funcs.collect_stats.write_collected_feats
Flag indicating whether to write collected features as npy files.
- Type: bool
Parameters:
- model – The model to collect statistics from.
- train_iter – The iterator for the training dataset.
- valid_iter – The iterator for the validation dataset.
- output_dir – The output directory for saving statistics.
- ngpu – The number of GPUs to use.
- log_interval – The logging interval for iterations.
- write_collected_feats – Flag to determine if features should be written.
Returns: This function does not return any value.
Return type: None
Raises: TypeError – If train_iter or valid_iter is not iterable.
Examples
>>> from pathlib import Path
>>> from torch.utils.data import DataLoader
>>> from espnet2.main_funcs.collect_stats import collect_stats
>>> # AbsESPnetModel is abstract; pass an instance of a concrete subclass.
>>> model = ...
>>> train_loader = DataLoader(...)
>>> valid_loader = DataLoader(...)
>>> collect_stats(model, train_loader, valid_loader, Path('./output'),
...               ngpu=1, log_interval=10, write_collected_feats=True)
NOTE
Ensure that the model is fully initialized before calling this function, and that each iterator yields (keys, batch) tuples, where keys is a list of strings and batch is a dictionary mapping names to tensors.