espnet2.samplers.category_balanced_sampler.CategoryBalancedSampler
class espnet2.samplers.category_balanced_sampler.CategoryBalancedSampler(batch_size: int, min_batch_size: int = 1, drop_last: bool = False, category2utt_file: str | None = None, epoch: int = 1, **kwargs)
Bases: AbsSampler
Sampler that maintains an equal distribution of categories (i.e., classes) within each minibatch. If the batch size is smaller than the number of classes, all samples in the minibatch will belong to different classes.
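The balancing idea can be sketched in plain Python. This is a minimal illustration, not the sampler's actual implementation: the function name `balanced_batches`, the round-robin strategy, and the toy data are all assumptions made for the example.

```python
import random
from typing import Dict, List, Tuple


def balanced_batches(
    category2utt: Dict[str, List[str]], batch_size: int, seed: int = 1
) -> List[Tuple[str, ...]]:
    """Illustrative only: interleave categories so batches mix them evenly."""
    rng = random.Random(seed)
    # Shuffle each category's utterances independently (copies, not in place).
    pools = {c: rng.sample(utts, len(utts)) for c, utts in sorted(category2utt.items())}
    # Take one utterance per category in turn until every pool is empty.
    interleaved: List[str] = []
    while any(pools.values()):
        for category in pools:
            if pools[category]:
                interleaved.append(pools[category].pop())
    # Cut the interleaved list into consecutive batches.
    return [
        tuple(interleaved[i : i + batch_size])
        for i in range(0, len(interleaved), batch_size)
    ]


# Hypothetical toy data: three categories with two utterances each.
toy = {
    "spk1": ["a1", "a2"],
    "spk2": ["b1", "b2"],
    "spk3": ["c1", "c2"],
}
batches = balanced_batches(toy, batch_size=2)
for batch in batches:
    print(batch)
```

With three equally sized categories and a batch size of 2, every batch in this sketch contains utterances from two different categories, which matches the behavior described for batch sizes smaller than the number of classes.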
The key_file is a text file that describes each sample name. It should be formatted as follows:
utterance_id_a
utterance_id_b
utterance_id_c
The first column is referred to, so a ‘shape file’ can also be used, which has the following format:
utterance_id_a 100,80
utterance_id_b 400,80
utterance_id_c 512,80
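Because only the first column is read, both layouts yield the same utterance IDs. A small sketch of that parsing rule follows; the helper name `read_utterance_ids` is hypothetical and is not part of ESPnet.

```python
from typing import List


def read_utterance_ids(lines: List[str]) -> List[str]:
    """Return the first whitespace-separated column of each non-empty line."""
    ids = []
    for line in lines:
        fields = line.split()
        if fields:  # skip blank lines
            ids.append(fields[0])
    return ids


# One utterance per line, as it would appear in a key file and a shape file.
key_lines = ["utterance_id_a", "utterance_id_b", "utterance_id_c"]
shape_lines = [
    "utterance_id_a 100,80",
    "utterance_id_b 400,80",
    "utterance_id_c 512,80",
]

# Both layouts give the same list of utterance IDs.
assert read_utterance_ids(key_lines) == read_utterance_ids(shape_lines)
```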
batch_size
The size of each batch.
- Type: int
min_batch_size
The minimum size of each batch. Default is 1.
- Type: int
drop_last
Whether to drop the last batch if it’s smaller than batch_size. Default is False.
- Type: bool
category2utt_file
Path to the file mapping categories to utterances.
- Type: Optional[str]
epoch
The epoch number for seeding the random number generator.
- Type: int
Parameters:
- batch_size (int) – The size of the minibatch.
- min_batch_size (int, optional) – Minimum size of the minibatch. Default is 1.
- drop_last (bool, optional) – If True, drop the last batch if it’s smaller than batch_size. Default is False.
- category2utt_file (str, optional) – Path to the file mapping categories to utterances. The signature default is None, but a path must be provided.
- epoch (int, optional) – The epoch number for random seed. Default is 1.
- **kwargs – Additional keyword arguments for customization.
Returns: An iterator that yields batches of utterances.
Return type: Iterator[Tuple[str, ...]]
Examples
>>> sampler = CategoryBalancedSampler(batch_size=4, category2utt_file='path/to/file.txt')
>>> for batch in sampler:
...     print(batch)
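The example needs a category2utt file on disk. The sketch below writes and reads one under an assumed layout that this page does not specify: one line per category, with the category name followed by its utterance IDs, separated by spaces. The helper names and the speaker labels are hypothetical; check the layout against your ESPnet version before relying on it.

```python
import os
import tempfile
from typing import Dict, List

# Hypothetical mapping, e.g. speaker label -> utterance IDs.
category2utt = {
    "spk1": ["utterance_id_a", "utterance_id_b"],
    "spk2": ["utterance_id_c"],
}


def write_category2utt(mapping: Dict[str, List[str]], path: str) -> None:
    """Write '<category> <utt_1> <utt_2> ...' per line (assumed layout)."""
    with open(path, "w", encoding="utf-8") as f:
        for category, utts in sorted(mapping.items()):
            f.write(" ".join([category, *utts]) + "\n")


def read_category2utt(path: str) -> Dict[str, List[str]]:
    """Parse the assumed layout back into a dict."""
    mapping = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if fields:  # skip blank lines
                mapping[fields[0]] = fields[1:]
    return mapping


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "category2utt")
    write_category2utt(category2utt, path)
    round_trip = read_category2utt(path)

print(round_trip)
```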
NOTE
The random seed is initialized based on the provided epoch number to ensure reproducibility across different runs.
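The effect of epoch-based seeding can be shown with the standard library alone. This is a sketch of the general technique, assuming a dedicated `random.Random` instance; the function `shuffled_for_epoch` is hypothetical and does not reproduce how the sampler seeds its generator internally.

```python
import random
from typing import List


def shuffled_for_epoch(utterances: List[str], epoch: int) -> List[str]:
    """Shuffle a copy of the list with an RNG seeded by the epoch number."""
    rng = random.Random(epoch)
    shuffled = list(utterances)
    rng.shuffle(shuffled)
    return shuffled


utts = [f"utt{i}" for i in range(10)]

# The same epoch always reproduces the same order.
assert shuffled_for_epoch(utts, epoch=1) == shuffled_for_epoch(utts, epoch=1)
```

Passing a different epoch number changes the seed, so the order generally differs from one epoch to the next while remaining repeatable for any given epoch.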