espnet2.iterators.category_chunk_iter_factory.CategoryChunkIterFactory
class espnet2.iterators.category_chunk_iter_factory.CategoryChunkIterFactory(dataset, batch_size: int, batches: AbsSampler | Sequence[Sequence[Any]], chunk_length: int | str, chunk_shift_ratio: float = 0.5, num_cache_chunks: int = 1024, num_samples_per_epoch: int | None = None, seed: int = 0, shuffle: bool = False, num_workers: int = 0, collate_fn=None, pin_memory: bool = False, excluded_key_prefixes: List[str] | None = None, discard_short_samples: bool = True, default_fs: int | None = None, chunk_max_abs_length: int | None = None)
Bases: AbsIterFactory
Creates chunks from a sequence.
This class generates category-balanced chunks within each batch instead of generating chunks per category. It is a modified version of ChunkIterFactory, designed to handle varying chunk lengths and sampling frequencies.
category_sample_iter_factory
An instance of SequenceIterFactory to generate sequences for sampling.
- Type: SequenceIterFactory
num_cache_chunks
Maximum number of cached chunks.
- Type: int
chunk_lengths
List of possible chunk lengths derived from the input.
- Type: List[int]
chunk_max_abs_length
Maximum absolute length of chunks.
- Type: int
chunk_shift_ratio
Ratio for shifting chunks during sampling.
- Type: float
batch_size
Size of each batch.
- Type: int
seed
Seed for random number generation.
- Type: int
shuffle
Indicates whether to shuffle the data.
- Type: bool
default_fs
Default sampling frequency.
- Type: Optional[int]
discard_short_samples
Whether to discard samples shorter than the shortest chunk length.
- Type: bool
excluded_key_pattern
Regular expression pattern for keys to exclude from length consistency checks.
- Type: str
collate_fn
Function to collate batches.
- Type: Optional[callable]
Parameters:
- dataset – The dataset to sample from.
- batch_size (int) – The size of each batch.
- batches (Union[AbsSampler, Sequence[Sequence[Any]]]) – The batches to sample from, either as a sampler or a sequence.
- chunk_length (Union[int, str]) – The length of each chunk; can be a single value or a range.
- chunk_shift_ratio (float, optional) – Ratio for shifting chunks (default is 0.5).
- num_cache_chunks (int, optional) – Number of chunks to cache (default is 1024).
- num_samples_per_epoch (Optional[int], optional) – Number of samples per epoch (default is None).
- seed (int, optional) – Seed for random number generation (default is 0).
- shuffle (bool, optional) – Whether to shuffle the data (default is False).
- num_workers (int, optional) – Number of workers for data loading (default is 0).
- collate_fn (Optional[callable], optional) – Function to collate batches (default is None).
- pin_memory (bool, optional) – Whether to pin memory (default is False).
- excluded_key_prefixes (Optional[List[str]], optional) – List of prefixes for keys to exclude (default is None).
- discard_short_samples (bool, optional) – Whether to discard samples shorter than the shortest chunk length (default is True).
- default_fs (Optional[int], optional) – Default sampling frequency (default is None).
- chunk_max_abs_length (Optional[int], optional) – Maximum absolute length of chunks (default is None).
Returns: An iterator yielding tuples of IDs and corresponding batches of tensors.
Return type: Iterator[Tuple[List[str], Dict[str, torch.Tensor]]]
Raises:
- ValueError – If the chunk_length string is invalid or empty.
- RuntimeError – If sequences do not have the same length.
######### Examples
>>> batches = [["id1"], ["id2"], ...]
>>> batch_size = 128
>>> chunk_length = 1000
>>> iter_factory = CategoryChunkIterFactory(
...     dataset, batch_size=batch_size, batches=batches, chunk_length=chunk_length)
>>> it = iter_factory.build_iter(epoch)
>>> for ids, batch in it:
... ...
NOTE
This class does not rebuild the sampler for each epoch, as the randomness from samplers is usually sufficient.
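The chunk_length parameter accepts either a plain integer or a string. Based on the "single value or a range" description and the ValueError listed above, a comma-separated list and a min-max range are the assumed string forms. The helper below is a hypothetical, stand-alone sketch of that parsing for illustration only; it is not part of ESPnet.

```python
from typing import List, Union


def parse_chunk_length(chunk_length: Union[int, str]) -> List[int]:
    """Hypothetical helper illustrating the assumed 'single value or range' syntax."""
    if isinstance(chunk_length, int):
        return [chunk_length]
    if not chunk_length:
        raise ValueError("chunk_length must not be empty, e.g. '500' or '300-500'")
    lengths: List[int] = []
    for item in chunk_length.split(","):
        bounds = [int(v) for v in item.split("-")]
        if len(bounds) == 1:
            lengths.append(bounds[0])          # single value, e.g. "500"
        elif len(bounds) == 2:
            lengths.extend(range(bounds[0], bounds[1] + 1))  # range, e.g. "300-500"
        else:
            raise ValueError(f"invalid chunk_length item: {item!r}")
    return lengths


print(parse_chunk_length(1000))        # [1000]
print(parse_chunk_length("300-305"))   # [300, 301, 302, 303, 304, 305]
print(parse_chunk_length("500,800"))   # [500, 800]
```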
build_iter(epoch: int, shuffle: bool | None = None) → Iterator[Tuple[List[str], Dict[str, Tensor]]]
Builds an iterator that generates chunks from a dataset based on the specified parameters.
This method utilizes the internal category sample iterator factory to yield batches of data that are divided into chunks according to the specified chunk lengths and other settings.
- Parameters:
- epoch (int) – The current epoch number, used for random state initialization.
- shuffle (Optional[bool]) – If True, shuffles the data. If None, uses the value from the instance variable.
- Yields: Iterator[Tuple[List[str], Dict[str, torch.Tensor]]] – A tuple containing a list of IDs and a dictionary of batched data tensors.
- Raises: RuntimeError – If the sequences in the batch do not have the same length according to the specified conditions.
######### Examples
>>> iter_factory = CategoryChunkIterFactory(dataset, batch_size=32,
... batches=[["id1"], ["id2"]], chunk_length=100)
>>> for ids, batch in iter_factory.build_iter(epoch=1):
... print(ids, batch)
NOTE
This iterator supports multiple chunk lengths and keeps chunks for each length until collecting the specified number of batches.
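The note above can be illustrated with a toy, self-contained sketch (not ESPnet code): chunks are bucketed by their length, and a batch is emitted only once a bucket has accumulated batch_size same-length chunks. All names below are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List, Tuple


def toy_chunk_batches(
    utterances: Dict[str, List[int]],   # utt_id -> a toy "signal" (list of samples)
    chunk_lengths: List[int],
    batch_size: int,
    seed: int = 0,
) -> Iterator[Tuple[List[str], List[List[int]]]]:
    rng = random.Random(seed)
    # One cache bucket per chunk length, echoing the note above.
    cache: Dict[int, List[Tuple[str, List[int]]]] = defaultdict(list)
    for utt_id, signal in utterances.items():
        candidates = [length for length in chunk_lengths if length <= len(signal)]
        if not candidates:
            continue  # analogous to discard_short_samples=True
        length = rng.choice(candidates)
        # Cut the utterance into non-overlapping chunks of the chosen length.
        for start in range(0, len(signal) - length + 1, length):
            cache[length].append((utt_id, signal[start:start + length]))
        # Emit a batch as soon as one bucket holds batch_size same-length chunks.
        while len(cache[length]) >= batch_size:
            batch, cache[length] = cache[length][:batch_size], cache[length][batch_size:]
            ids, chunks = zip(*batch)
            yield list(ids), list(chunks)
    # Leftover chunks in partially filled buckets are simply dropped in this toy.


utts = {f"utt{i}": list(range(10)) for i in range(4)}
for ids, chunks in toy_chunk_batches(utts, chunk_lengths=[3, 5], batch_size=2):
    print(ids, [len(c) for c in chunks])
```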
prepare_for_collate(id_list, batches)
Prepares the batches for collation by converting tensor values to numpy arrays.
This method takes a list of IDs and a corresponding batch of tensors, and returns a list of tuples, where each tuple contains an ID and a dictionary of numpy arrays representing the batch data.
- Parameters:
- id_list (List[str]) – A list of identifiers corresponding to the batches.
- batches (Dict[str, List[torch.Tensor]]) – A dictionary where keys are the names of the batch components and values are lists of tensors.
- Returns: A list of tuples, each containing an identifier and a dictionary of NumPy arrays for the batch data.
- Return type: List[Tuple[str, Dict[str, np.ndarray]]]
######### Examples
>>> id_list = ['id1', 'id2']
>>> batches = {
... 'feature': [torch.tensor([[1, 2], [3, 4]]), torch.tensor([[5, 6]])],
... 'label': [torch.tensor([1]), torch.tensor([0])]
... }
>>> prepared_batches = prepare_for_collate(id_list, batches)
>>> print(prepared_batches)
[('id1', {'feature': array([[1, 2], [3, 4]]), 'label': array([1])}),
('id2', {'feature': array([[5, 6]]), 'label': array([0])})]
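The per-sample NumPy dictionaries returned above are in the shape a collate_fn typically consumes. The following toy collate function is an assumption for illustration (not the collate_fn used by ESPnet): it pads each key to the longest sample along the first axis and stacks the result into batch tensors.

```python
import numpy as np
import torch

# Per-sample output in the same form as the prepare_for_collate example above.
prepared = [
    ("id1", {"feature": np.array([[1, 2], [3, 4]]), "label": np.array([1])}),
    ("id2", {"feature": np.array([[5, 6]]), "label": np.array([0])}),
]


def toy_collate(samples):
    """Pad each key along the first axis and stack into one batch tensor."""
    ids = [uid for uid, _ in samples]
    batch = {}
    for key in samples[0][1]:
        arrays = [data[key] for _, data in samples]
        max_len = max(a.shape[0] for a in arrays)
        padded = np.stack([
            np.pad(a, [(0, max_len - a.shape[0])] + [(0, 0)] * (a.ndim - 1))
            for a in arrays
        ])
        batch[key] = torch.from_numpy(padded)
    return ids, batch


ids, batch = toy_collate(prepared)
print(ids)                     # ['id1', 'id2']
print(batch["feature"].shape)  # torch.Size([2, 2, 2])
print(batch["label"].shape)    # torch.Size([2, 1])
```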