espnet3.components.dataset.CombinedDataset
class espnet3.components.dataset.CombinedDataset(datasets: List[Any], transforms: List[Tuple[Callable, Callable]], use_espnet_preprocessor: bool = False)
Bases: object
Combines multiple datasets into a single unified dataset-like interface.
This class provides access to multiple datasets as if they were a single dataset. Each dataset is paired with its own transform and preprocessor, which are applied in that order to every sample drawn from it. It also supports passing a sample UID to the preprocessor for ESPnet-style preprocessing.
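Unified access requires mapping a global index to one dataset and an index local to that dataset. The sketch below shows one way to do this. It assumes the datasets are concatenated in list order and that negative indices are rejected. The helper `locate` is hypothetical and is not taken from espnet3.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence, Tuple


def locate(lengths: Sequence[int], idx: int) -> Tuple[int, int]:
    """Map a global index to (dataset position, local index).

    Hypothetical helper: assumes datasets are concatenated in list order
    and that negative indices are not supported.
    """
    # Cumulative end offsets, e.g. lengths [3, 5] -> ends [3, 8].
    ends: List[int] = list(accumulate(lengths))
    total = ends[-1] if ends else 0
    if not 0 <= idx < total:
        raise IndexError(f"Index {idx} out of range for length {total}")
    # Position of the first dataset whose end offset is greater than idx
    # (this also skips empty datasets correctly).
    ds_pos = bisect_right(ends, idx)
    start = ends[ds_pos - 1] if ds_pos > 0 else 0
    return ds_pos, idx - start


# Two datasets of length 3 and 5: global index 5 is local index 2 of the second.
print(locate([3, 5], 5))  # (1, 2)
```

This sketch recomputes `ends` on every call. Caching the offsets once at construction time would keep each lookup logarithmic in the number of datasets.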
- Parameters:
datasets (List[Any]) - A list of dataset instances. Each must implement __getitem__ and __len__.
transforms (List[Tuple[Callable, Callable]]) - A list of (transform, preprocessor) tuples. Each pair corresponds to the dataset at the same position in datasets.
- transform(sample) is applied first.
- preprocessor(uid, sample) or preprocessor(sample), depending on use_espnet_preprocessor, is then applied to the transformed sample.
use_espnet_preprocessor (bool) - If True, the preprocessor is called as preprocessor(uid, sample), as expected by ESPnet AbsPreprocessor-compatible pipelines. Defaults to False.
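The per-sample order of operations described by these parameters can be sketched as follows. This is an illustration and not the espnet3 code. The helper `apply_pipeline` and the toy callables are hypothetical. The caller is assumed to supply the `uid` string, because this page does not document how CombinedDataset derives it.

```python
from typing import Any, Callable, Dict


def apply_pipeline(
    sample: Any,
    uid: str,
    transform: Callable[[Any], Dict[str, Any]],
    preprocessor: Callable[..., Dict[str, Any]],
    use_espnet_preprocessor: bool,
) -> Dict[str, Any]:
    """Apply one dataset's (transform, preprocessor) pair to a raw sample."""
    # Step 1: the dataset-specific transform always runs first.
    data = transform(sample)
    # Step 2a: AbsPreprocessor-style callables receive (uid, data).
    if use_espnet_preprocessor:
        return preprocessor(uid, data)
    # Step 2b: plain callables receive only the transformed sample.
    return preprocessor(data)


def to_dict(sample: str) -> Dict[str, Any]:
    # Toy transform: wrap a raw string in a dict.
    return {"text": sample}


def upper(data: Dict[str, Any]) -> Dict[str, Any]:
    # Toy plain preprocessor: upper-case every value.
    return {k: v.upper() for k, v in data.items()}


def add_uid(uid: str, data: Dict[str, Any]) -> Dict[str, Any]:
    # Toy ESPnet-style preprocessor: record the uid alongside the data.
    return {**data, "uid": uid}


print(apply_pipeline("hello", "utt5", to_dict, upper, False))  # {'text': 'HELLO'}
print(apply_pipeline("hello", "utt5", to_dict, add_uid, True))  # {'text': 'hello', 'uid': 'utt5'}
```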
get_text_available
True if all datasets implement get_text(idx).
- Type: bool
multiple_iterator
True if any dataset is a subclass of ShardedDataset.
- Type: bool
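Both flags can be derived from the dataset list with `all` and `any`. This is a minimal sketch. The `ShardedDataset` class below is a local stand-in and not the real espnet3 base class. Checking for a callable `get_text` attribute is an assumed way of testing support.

```python
from typing import Any, List


class ShardedDataset:
    """Local stand-in for espnet3's ShardedDataset base class."""

    def shard(self, shard_idx: int) -> Any:
        raise NotImplementedError


def get_text_available(datasets: List[Any]) -> bool:
    # True only if every dataset exposes a callable get_text().
    return all(callable(getattr(ds, "get_text", None)) for ds in datasets)


def multiple_iterator(datasets: List[Any]) -> bool:
    # True if at least one dataset is a ShardedDataset (or subclass) instance.
    return any(isinstance(ds, ShardedDataset) for ds in datasets)


class TextList(list):
    """Toy dataset: a list whose items are their own text."""

    def get_text(self, idx: int) -> str:
        return self[idx]


class ToyShards(ShardedDataset):
    """Toy sharded dataset that does not implement get_text()."""
```

Because the two flags use `all` and `any` respectively, a single sharded dataset enables the multiple-iterator flag, while a single dataset lacking `get_text` disables text access.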
NOTE
At initialization, the first sample from each dataset is passed through its associated transform to check that all datasets produce dictionaries with the same set of keys. This ensures consistency across the combined dataset. An AssertionError is raised if the keys differ.
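The initialization check can be sketched as follows. Only the transform is applied, as the note states. The function name `check_keys` is hypothetical. Each dataset is assumed to be non-empty and to return a dict-like sample after transformation.

```python
from typing import Any, Callable, Optional, Sequence, Set, Tuple


def check_keys(
    datasets: Sequence[Any],
    transforms: Sequence[Tuple[Callable, Callable]],
) -> None:
    """Assert that every dataset's first transformed sample has the same keys."""
    reference: Optional[Set[str]] = None
    for ds, (transform, _preprocessor) in zip(datasets, transforms):
        # Only the transform is applied here; the preprocessor is skipped.
        keys = set(transform(ds[0]).keys())
        if reference is None:
            reference = keys
        assert keys == reference, (
            f"Inconsistent keys: {sorted(keys)} vs {sorted(reference)}"
        )


def as_text(sample: str) -> dict:
    # Toy transform producing a "text" key.
    return {"text": sample}


def as_speech(sample: str) -> dict:
    # Toy transform producing a "speech" key, inconsistent with as_text.
    return {"speech": sample}


# Consistent keys: no error is raised.
check_keys([["a"], ["b"]], [(as_text, None), (as_text, None)])
```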
- Raises:
- IndexError - If a requested index is outside the range of the combined dataset.
- ValueError - If the index is a non-integer string or otherwise cannot be cast to int.
- RuntimeError - If get_text() or shard() is called but not supported.
- AssertionError - If output keys from different datasets are inconsistent.
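One way to produce the IndexError and ValueError behaviour listed above is to coerce and bounds-check the index before the lookup. This sketch assumes that string indices such as "5" are accepted and that negative indices are rejected. `normalize_index` is a hypothetical helper.

```python
from typing import Any


def normalize_index(idx: Any, total: int) -> int:
    """Coerce idx to int and check it against the combined length."""
    if not isinstance(idx, int):
        try:
            # Accepts int-like strings such as "5"; "abc" fails here.
            # Note that int() truncates floats, e.g. 2.7 becomes 2.
            idx = int(idx)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Index must be castable to int, got {idx!r}") from err
    # Negative indices are treated as out of range in this sketch.
    if not 0 <= idx < total:
        raise IndexError(f"Index {idx} out of range for length {total}")
    return idx
```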
Example
>>> dataset = CombinedDataset(
... datasets=[ds1, ds2],
... transforms=[
... (transform1, preprocessor),
... (transform2, preprocessor),
... ],
... use_espnet_preprocessor=True
... )
>>> sample = dataset[5]
>>> print(sample["text"])
get_text(idx)
Retrieve the target text string for a given index.
This method delegates to the underlying dataset's get_text(idx) method. It is typically used to extract text sequences, for example to train tokenizers or language models.
- Raises: RuntimeError - If not all datasets implement get_text(idx).
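The delegation can be sketched as below. For brevity, it assumes the global index is resolved to a local index by walking the datasets in order with a linear scan. The function `combined_get_text` and the toy `Transcripts` dataset are hypothetical.

```python
from typing import Any, Sequence


def combined_get_text(datasets: Sequence[Any], idx: int) -> str:
    """Return the text for global index idx by delegating to one dataset."""
    # Mirror the documented contract: every dataset must support get_text().
    if not all(callable(getattr(ds, "get_text", None)) for ds in datasets):
        raise RuntimeError("get_text() is not supported by all datasets")
    offset = 0
    for ds in datasets:
        if 0 <= idx - offset < len(ds):
            # Delegate with the index local to this dataset.
            return ds.get_text(idx - offset)
        offset += len(ds)
    raise IndexError(f"Index {idx} out of range for length {offset}")


class Transcripts:
    """Toy dataset exposing __len__ and get_text()."""

    def __init__(self, texts: Sequence[str]) -> None:
        self.texts = list(texts)

    def __len__(self) -> int:
        return len(self.texts)

    def get_text(self, idx: int) -> str:
        return self.texts[idx]
```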
shard(shard_idx: int)
Return a sharded version of the combined dataset.
This is used when handling large datasets that are split into shards for efficiency and distributed processing (ESPnet multiple-iterator mode). All datasets must be subclasses of espnet3.data.dataset.ShardedDataset, and implement a shard() method.
- Parameters: shard_idx (int) - Index of the shard to retrieve.
- Returns: A new CombinedDataset containing the sharded datasets.
- Return type: CombinedDataset
- Raises: RuntimeError - If any dataset does not support sharding.
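The following is a sketch of the sharding precondition and the per-dataset call. It assumes that each `shard()` returns that dataset's portion for `shard_idx`. The real method returns a new CombinedDataset, but here the shards are simply returned as a list. `ShardedDataset` is a local stand-in, and `shard_all` and `PrecomputedShards` are hypothetical.

```python
from typing import Any, List, Sequence


class ShardedDataset:
    """Local stand-in for espnet3.data.dataset.ShardedDataset."""

    def shard(self, shard_idx: int) -> Any:
        raise NotImplementedError


def shard_all(datasets: Sequence[Any], shard_idx: int) -> List[Any]:
    """Shard every dataset, failing if any of them cannot be sharded."""
    if not all(isinstance(ds, ShardedDataset) for ds in datasets):
        raise RuntimeError("All datasets must be ShardedDataset instances")
    # Each dataset returns its own shard. A caller would then wrap these
    # in a new combined dataset (reusing the transforms is an assumption).
    return [ds.shard(shard_idx) for ds in datasets]


class PrecomputedShards(ShardedDataset):
    """Toy dataset whose shards are precomputed lists."""

    def __init__(self, shards: List[List[int]]) -> None:
        self.shards = shards

    def shard(self, shard_idx: int) -> List[int]:
        return self.shards[shard_idx]
```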
property use_espnet_collator
