espnet2.samplers.folded_batch_sampler.FoldedBatchSampler
class espnet2.samplers.folded_batch_sampler.FoldedBatchSampler(batch_size: int, shape_files: Tuple[str, ...] | List[str], fold_lengths: Sequence[int], min_batch_size: int = 1, sort_in_batch: str = 'descending', sort_batch: str = 'ascending', drop_last: bool = False, utt2category_file: str | None = None)
Bases: AbsSampler
FoldedBatchSampler is a batch sampler for variable-length inputs. It reduces ("folds") the batch size for longer utterances according to fold_lengths and, when a category file is given, batches each utterance category separately.
The sampler groups utterance keys into batches that respect batch_size and min_batch_size, then orders the utterances within each batch and the batches themselves as requested.
batch_size
The size of each batch.
- Type: int
shape_files
A list of files that contain shape information for each utterance.
- Type: Union[Tuple[str, …], List[str]]
sort_in_batch
The sorting order for elements within each batch. Should be either “ascending” or “descending”.
- Type: str
sort_batch
The sorting order for batches. Should be either “ascending” or “descending”.
- Type: str
drop_last
Whether to drop the last incomplete batch if its size is less than batch_size.
- Type: bool
batch_list
A list containing the batches, each represented as a tuple of utterance keys.
- Type: List[Tuple[str, …]]
Parameters:
- batch_size (int) – The number of samples in each batch.
- shape_files (Union[Tuple[str, ...], List[str]]) – Paths to the files that contain the shape information for each utterance.
- fold_lengths (Sequence[int]) – A sequence of lengths used to determine the folding of batches.
- min_batch_size (int, optional) – The minimum number of samples in a batch. Defaults to 1.
- sort_in_batch (str, optional) – Sorting order for elements within a batch. Defaults to “descending”.
- sort_batch (str, optional) – Sorting order for batches. Defaults to “ascending”.
- drop_last (bool, optional) – Whether to drop the last batch if it is smaller than batch_size. Defaults to False.
- utt2category_file (Optional[str], optional) – An optional file that maps utterances to categories.
Returns: An iterator that yields batches of utterance keys.
Return type: Iterator[Tuple[str, …]]
Raises:
- ValueError – If sort_batch or sort_in_batch is not “ascending” or “descending”.
- RuntimeError – If there are mismatches in the keys between shape files or if no batches can be formed.
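The error conditions listed under Raises can be checked before the sampler is constructed. The following is a minimal sketch, not ESPnet code: `read_shape_keys` and `precheck_sampler_inputs` are hypothetical helpers, and each shape-file line is assumed to have the form `<utterance key> <comma-separated integers>`.

```python
from typing import Dict, List, Sequence

ALLOWED_ORDERS = ("ascending", "descending")


def read_shape_keys(lines: Sequence[str]) -> List[str]:
    """Collect utterance keys from lines assumed to look like 'utt1 120,80'."""
    keys = []
    for line in lines:
        line = line.strip()
        if line:
            keys.append(line.split(maxsplit=1)[0])
    return keys


def precheck_sampler_inputs(
    sort_in_batch: str,
    sort_batch: str,
    shape_tables: Dict[str, Sequence[str]],
) -> None:
    """Raise early on the documented error conditions (hypothetical helper)."""
    for name, value in (("sort_in_batch", sort_in_batch), ("sort_batch", sort_batch)):
        if value not in ALLOWED_ORDERS:
            raise ValueError(f"{name} must be ascending or descending: {value}")

    reference_name = None
    reference_keys = set()
    for name, lines in shape_tables.items():
        keys = set(read_shape_keys(lines))
        if reference_name is None:
            # The first table defines the expected key set and must not be empty.
            if not keys:
                raise RuntimeError(f"no utterances found: {name}")
            reference_name, reference_keys = name, keys
        elif keys != reference_keys:
            raise RuntimeError(f"keys are mismatched between {name} != {reference_name}")


# Matching keys in both (in-memory) shape tables: no exception is raised.
precheck_sampler_inputs(
    "ascending",
    "descending",
    {
        "shapes1.txt": ["utt1 120,80", "utt2 250,80"],
        "shapes2.txt": ["utt2 31", "utt1 12"],
    },
)
```

Running such a check first makes it easier to see which file or argument is at fault before any batches are built.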
Examples
>>> from espnet2.samplers.folded_batch_sampler import FoldedBatchSampler
>>> sampler = FoldedBatchSampler(
...     batch_size=4,
...     shape_files=["shapes1.txt", "shapes2.txt"],
...     fold_lengths=[100, 200],
...     min_batch_size=2,
...     sort_in_batch="ascending",
...     sort_batch="descending",
... )
>>> for batch in sampler:
...     print(batch)
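The example above assumes that shapes1.txt and shapes2.txt already exist. As a rough sketch, the snippet below writes two toy shape files in a temporary directory. The one-line-per-utterance layout `<key> <dim0>,<dim1>,...` is an assumption about the expected format, and the utterance keys, shapes, and helper functions are made up for illustration.

```python
import tempfile
from pathlib import Path

# Hypothetical toy data: key -> (frames, feature dim) and key -> (token count,).
speech_shapes = {"utt1": (120, 80), "utt2": (250, 80), "utt3": (90, 80)}
text_shapes = {"utt1": (12,), "utt2": (31,), "utt3": (8,)}


def write_shape_file(path, shapes):
    """Write one '<key> <dim0>,<dim1>,...' line per utterance."""
    with open(path, "w", encoding="utf-8") as f:
        for key, shape in shapes.items():
            f.write(f"{key} {','.join(str(n) for n in shape)}\n")


def read_shape_file(path):
    """Parse a shape file back into a dict of key -> tuple of ints."""
    shapes = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            key, csv = line.split(maxsplit=1)
            shapes[key] = tuple(int(n) for n in csv.strip().split(","))
    return shapes


workdir = Path(tempfile.mkdtemp())
shape_files = [str(workdir / "shapes1.txt"), str(workdir / "shapes2.txt")]
write_shape_file(shape_files[0], speech_shapes)
write_shape_file(shape_files[1], text_shapes)
```

The resulting `shape_files` list could then be passed as the `shape_files` argument. Both files must contain the same set of utterance keys.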
NOTE
This sampler is particularly useful when input lengths vary significantly, as in speech processing or natural language processing.
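To illustrate why folding helps with widely varying lengths, here is a minimal sketch of one way a per-batch size could be derived from fold_lengths. It assumes the batch size is divided by one plus the largest number of whole fold lengths that fit into an utterance's lengths. `folded_batch_size` is a hypothetical function written for this illustration, and the actual implementation may differ in detail.

```python
from typing import Sequence


def folded_batch_size(
    lengths: Sequence[int],
    fold_lengths: Sequence[int],
    batch_size: int,
    min_batch_size: int = 1,
) -> int:
    """Shrink batch_size for long utterances (assumed rule, for illustration)."""
    # Number of whole folds per shape file; the largest one drives the reduction.
    factor = max(length // fold for length, fold in zip(lengths, fold_lengths))
    return max(min_batch_size, batch_size // (1 + factor))


# lengths holds one length per shape file, e.g. (input frames, output tokens).
short = folded_batch_size([90, 10], [100, 200], batch_size=16)
medium = folded_batch_size([250, 30], [100, 200], batch_size=16)
long_ = folded_batch_size([1000, 50], [100, 200], batch_size=16, min_batch_size=2)
print(short, medium, long_)  # prints "16 5 2"
```

Under this rule, short utterances fill the full batch, while long ones are placed in smaller batches, which keeps the padded size of each batch roughly bounded.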