espnet2.samplers.num_elements_batch_sampler.NumElementsBatchSampler
class espnet2.samplers.num_elements_batch_sampler.NumElementsBatchSampler(batch_bins: int, shape_files: Tuple[str, ...] | List[str], min_batch_size: int = 1, sort_in_batch: str = 'descending', sort_batch: str = 'ascending', drop_last: bool = False, padding: bool = True)
Bases: AbsSampler
NumElementsBatchSampler is a batch sampler that builds mini-batches according to a budget on the number of elements (bins) rather than a fixed number of samples. It is designed for variable-length sequences: each batch is filled up to the maximum number of bins given by batch_bins while containing at least min_batch_size samples. Samples can be sorted in ascending or descending order within each batch, and the batches themselves can be ordered the same way.
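The bin budget described above can be illustrated with a small greedy sketch. This is not the ESPnet implementation; it assumes a single length per sample, padding to the longest sequence in the batch, and a hypothetical helper name `greedy_batches`. The exact rule ESPnet uses to close a batch may differ.

```python
from typing import Dict, List, Tuple


def greedy_batches(
    lengths: Dict[str, int], batch_bins: int, min_batch_size: int = 1
) -> List[Tuple[str, ...]]:
    """Group keys so that the padded size of each batch stays within batch_bins.

    Hypothetical helper for illustration only.
    """
    # Sort by length so that each batch holds sequences of similar size.
    keys = sorted(lengths, key=lambda k: lengths[k])
    batches: List[Tuple[str, ...]] = []
    current: List[str] = []
    for key in keys:
        candidate = current + [key]
        # Padded cost: every sequence is padded to the longest one in the batch.
        cost = len(candidate) * max(lengths[k] for k in candidate)
        if cost > batch_bins and len(current) >= min_batch_size:
            # Budget exceeded and the minimum size is met: close the batch.
            batches.append(tuple(current))
            current = [key]
        else:
            current = candidate
    if current:
        # Keep the leftover samples as a final batch.
        batches.append(tuple(current))
    return batches


example_lengths = {"a": 2, "b": 3, "c": 4, "d": 10}
batches_min1 = greedy_batches(example_lengths, batch_bins=10)
batches_min2 = greedy_batches(example_lengths, batch_bins=10, min_batch_size=2)
```

With `min_batch_size=2`, the sketch lets a batch exceed the budget rather than emit a batch with a single sample, which shows why the two parameters interact.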
batch_bins
The maximum number of bins allowed per batch.
- Type: int
shape_files
List of paths to files containing sequence lengths.
- Type: Union[Tuple[str, …], List[str]]
sort_in_batch
The order in which to sort elements within each batch; options are “ascending” or “descending”.
- Type: str
sort_batch
The order in which to sort batches; options are “ascending” or “descending”.
- Type: str
drop_last
Whether to drop the last incomplete batch if it has fewer than min_batch_size samples.
- Type: bool
batch_list
The final list of batches created by the sampler.
- Type: List[Tuple[str, …]]
Parameters:
- batch_bins (int) – Maximum number of bins in each batch.
- shape_files (Union[Tuple[str, ...], List[str]]) – List of shape files containing sequence lengths.
- min_batch_size (int, optional) – Minimum number of samples in a batch. Defaults to 1.
- sort_in_batch (str, optional) – Sorting order for elements in a batch. Can be “ascending” or “descending”. Defaults to “descending”.
- sort_batch (str, optional) – Sorting order for batches. Can be “ascending” or “descending”. Defaults to “ascending”.
- drop_last (bool, optional) – If True, drop the last batch if it is smaller than min_batch_size. Defaults to False.
- padding (bool, optional) – If True, ensures all features have the same dimension across the corpus. Defaults to True.
Returns: An iterator over the batches.
Return type: Iterator[Tuple[str, …]]
Raises:
- ValueError – If sort_batch or sort_in_batch is not “ascending” or “descending”.
- RuntimeError – If the keys in the shape files do not match or if no batches can be created.
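The two sort options and the ValueError condition can be sketched as follows. This is an illustrative stand-in, not ESPnet code: the function name `order_batches` is hypothetical, and it assumes the incoming batches are already ordered from shortest to longest, both within and across batches.

```python
from typing import List, Sequence, Tuple

VALID_ORDERS = ("ascending", "descending")


def order_batches(
    batches: Sequence[Sequence[str]],
    sort_in_batch: str = "descending",
    sort_batch: str = "ascending",
) -> List[Tuple[str, ...]]:
    """Apply within-batch and across-batch ordering (hypothetical helper)."""
    for name, value in (("sort_in_batch", sort_in_batch), ("sort_batch", sort_batch)):
        if value not in VALID_ORDERS:
            raise ValueError(f"{name} must be 'ascending' or 'descending': {value}")

    # Assumption: the input is ascending by length, so "descending" is a reversal.
    if sort_in_batch == "descending":
        ordered = [tuple(reversed(batch)) for batch in batches]
    else:
        ordered = [tuple(batch) for batch in batches]

    if sort_batch == "descending":
        ordered = ordered[::-1]
    return ordered


default_order = order_batches([("a", "b"), ("c", "d")])
longest_first = order_batches([("a", "b"), ("c", "d")], sort_batch="descending")
```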
Examples
>>> from espnet2.samplers.num_elements_batch_sampler import NumElementsBatchSampler
>>> sampler = NumElementsBatchSampler(batch_bins=10,
...                                   shape_files=["file1.csv", "file2.csv"])
>>> for batch in sampler:
...     print(batch)
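To try the example without real data, small shape files can be generated by hand. The sketch below assumes the layout commonly used for ESPnet shape files, where each line holds a sample key followed by comma-separated dimensions (for example `utt1 120,80`). The file names, keys, and the `read_shape_file` helper are hypothetical, and the parser is a plain-Python stand-in rather than the loader ESPnet uses.

```python
import tempfile
from pathlib import Path
from typing import Dict, Tuple


def read_shape_file(path: Path) -> Dict[str, Tuple[int, ...]]:
    """Parse lines of the assumed form '<key> <dim0>,<dim1>,...'."""
    shapes: Dict[str, Tuple[int, ...]] = {}
    for line in path.read_text().splitlines():
        if not line.strip():
            continue  # skip blank lines
        key, dims = line.split(maxsplit=1)
        shapes[key] = tuple(int(d) for d in dims.split(","))
    return shapes


# Write two tiny shape files: one for speech features, one for text tokens.
tmp_dir = Path(tempfile.mkdtemp())
speech_shape = tmp_dir / "speech_shape"
text_shape = tmp_dir / "text_shape"
speech_shape.write_text("utt1 120,80\nutt2 95,80\n")
text_shape.write_text("utt1 14\nutt2 9\n")

speech = read_shape_file(speech_shape)
text = read_shape_file(text_shape)

# Mirrors the documented RuntimeError: all shape files must share the same keys.
if set(speech) != set(text):
    raise RuntimeError("keys differ between shape files")
```

The resulting paths could then be passed as `shape_files=[str(speech_shape), str(text_shape)]`.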
NOTE
This sampler is useful for training models with variable-length input sequences, such as in speech or language processing tasks.