espnet2.samplers.sorted_batch_sampler.SortedBatchSampler
class espnet2.samplers.sorted_batch_sampler.SortedBatchSampler(batch_size: int, shape_file: str, sort_in_batch: str = 'descending', sort_batch: str = 'ascending', drop_last: bool = False)
Bases: AbsSampler
BatchSampler that yields batches of samples sorted by length.
This sampler groups samples into batches according to their lengths. Samples can be ordered by ascending or descending length within each batch, and the batches themselves can also be sorted in ascending or descending order.
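The ordering rules above can be sketched in plain Python. This is a minimal illustration, not ESPnet's implementation: the helper sorted_batches and its in-memory lengths dict (sample key to length) are hypothetical stand-ins for the class and its shape file. The None option for sort_in_batch and the drop_last handling are not modeled, and any short remainder is simply kept as the final batch.

```python
def sorted_batches(lengths, batch_size, sort_in_batch="descending", sort_batch="ascending"):
    """Group sample keys into length-sorted batches (illustrative sketch)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be greater than 0")
    if sort_in_batch not in ("ascending", "descending"):
        raise ValueError(f"unexpected sort_in_batch: {sort_in_batch}")
    if sort_batch not in ("ascending", "descending"):
        raise ValueError(f"unexpected sort_batch: {sort_batch}")
    # Sort all keys once by length; this order is preserved inside each batch.
    keys = sorted(
        lengths,
        key=lambda k: lengths[k],
        reverse=(sort_in_batch == "descending"),
    )
    # Chunk consecutive keys so each batch holds samples of similar length.
    batches = [tuple(keys[i:i + batch_size]) for i in range(0, len(keys), batch_size)]
    # The chunks follow the in-batch order; flip them when the batch order differs.
    if sort_batch != sort_in_batch:
        batches.reverse()
    return batches


lengths = {"a": 5, "b": 1, "c": 3, "d": 4, "e": 2}
print(sorted_batches(lengths, batch_size=2))
# [('b',), ('c', 'e'), ('a', 'd')]
```

With the default settings, each batch is internally longest-first while the batches run from shortest to longest. Grouping samples of similar length this way keeps the padding inside each batch small.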
batch_size
The size of each batch.
- Type: int
shape_file
Path to the file containing the shape information for each sample.
- Type: str
sort_in_batch
Defines the sorting order for samples within each batch. Can be ‘descending’, ‘ascending’, or None.
- Type: str
sort_batch
Defines the sorting order for the batches. Can be ‘ascending’ or ‘descending’.
- Type: str
drop_last
If True, drop the last batch if it is smaller than the batch size.
- Type: bool
Parameters:
- batch_size (int) – The size of each batch. Must be greater than 0.
- shape_file (str) – Path to the file containing the shape information for each sample.
- sort_in_batch (str) – ‘descending’, ‘ascending’ or None for sorting samples within each batch.
- sort_batch (str) – ‘ascending’ or ‘descending’ for sorting the batches.
- drop_last (bool) – If True, the last batch will be dropped if it has fewer than batch_size samples.
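The effect of drop_last on the number of batches can be checked with a small sketch. The helper split_into_batches is hypothetical; it assumes the keys are already sorted and applies only the rule stated above. ESPnet's own class may distribute leftover samples differently when drop_last is False, so treat this as an illustration of the documented contract rather than of the library's internals.

```python
def split_into_batches(keys, batch_size, drop_last=False):
    """Chunk pre-sorted keys into batches of batch_size (illustrative sketch)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be greater than 0")
    batches = [tuple(keys[i:i + batch_size]) for i in range(0, len(keys), batch_size)]
    # Documented rule: with drop_last=True, an undersized final batch is discarded.
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches


keys = [f"utt{i}" for i in range(7)]
print(len(split_into_batches(keys, 3)))                  # 3 batches: 3 + 3 + 1
print(len(split_into_batches(keys, 3, drop_last=True)))  # 2 batches; utt6 is dropped
```

With 7 samples and batch_size=3, 7 mod 3 leaves one sample over, and that sample is discarded only when drop_last is True.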
Raises:
- ValueError – If sort_in_batch or sort_batch is not one of the expected values.
- RuntimeError – If no samples or batches are found.
Examples
>>> from espnet2.samplers.sorted_batch_sampler import SortedBatchSampler
>>> sampler = SortedBatchSampler(batch_size=32, shape_file='shapes.csv',
...                              sort_in_batch='ascending',
...                              sort_batch='descending')
>>> for batch in sampler:
...     print(batch)
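NOTE
The shape file is a text file with one line per sample: the sample key, followed by the sample's shape as comma-separated integers (for example, uttA 100,80). The first integer is used as the sample's length for sorting.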
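A shape file in this layout can be read with a few lines of Python. The function read_shape_file is a hypothetical helper, not part of ESPnet; it assumes whitespace separates the key from the shape, skips blank lines, and keeps only the first shape value as the length.

```python
def read_shape_file(path):
    """Map each sample key to its length, taken from the first shape value."""
    lengths = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines
            # "uttA 100,80" -> key "uttA", shape "100,80"
            key, shape = line.split(maxsplit=1)
            dims = [int(v) for v in shape.split(",")]
            lengths[key] = dims[0]
    if not lengths:
        raise RuntimeError(f"0 lines found: {path}")
    return lengths
```

For a file containing the lines uttA 100,80 and uttB 201,80, this returns {"uttA": 100, "uttB": 201}.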