espnet2.train.collate_fn.CommonCollateFn
class espnet2.train.collate_fn.CommonCollateFn(float_pad_value: float | int = 0.0, int_pad_value: int = -32768, not_sequence: Collection[str] = ())
Bases: object
A callable class that implements a common collate function for batching audio and text data. Given a list of per-utterance samples, it pads the variable-length arrays to a common length and collects them into batch tensors suitable as model input, as is typical in speech processing tasks.
float_pad_value
The value used to pad float tensors.
- Type: Union[float, int]
int_pad_value
The value used to pad integer tensors.
- Type: int
not_sequence
A collection of keys that should not have lengths calculated.
- Type: set
Parameters:
- float_pad_value (Union[float, int], optional) – The float value to use for padding. Defaults to 0.0.
- int_pad_value (int, optional) – The integer value to use for padding. Defaults to -32768.
- not_sequence (Collection[str], optional) – A collection of keys that should not be treated as sequences. Defaults to an empty tuple.
Returns: When an instance is called on a batch, a tuple containing the list of unique identifiers and a dictionary of collated tensors.
Return type: Tuple[List[str], Dict[str, torch.Tensor]]
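The padding and length bookkeeping described above can be sketched in plain NumPy. This is only an illustration, not the ESPnet implementation: the helper name `sketch_collate`, the `_lengths` key suffix, and the rule that integer arrays take `int_pad_value` while everything else takes `float_pad_value` are assumptions made for this sketch, and it returns NumPy arrays rather than `torch.Tensor` objects.

```python
import numpy as np


def sketch_collate(data, float_pad_value=0.0, int_pad_value=-32768, not_sequence=()):
    """Illustrative NumPy-only collate (hypothetical helper, not ESPnet code)."""
    ids = [uid for uid, _ in data]
    samples = [sample for _, sample in data]
    output = {}
    for key in samples[0]:
        arrays = [s[key] for s in samples]
        # Assumption: integer arrays use int_pad_value, all others float_pad_value.
        pad = int_pad_value if arrays[0].dtype.kind == "i" else float_pad_value
        # The first axis is treated as the sequence length.
        max_len = max(a.shape[0] for a in arrays)
        shape = (len(arrays), max_len) + arrays[0].shape[1:]
        batch = np.full(shape, pad, dtype=arrays[0].dtype)
        for i, a in enumerate(arrays):
            batch[i, : a.shape[0]] = a
        output[key] = batch
        # Keys listed in not_sequence get no length entry.
        if key not in not_sequence:
            output[key + "_lengths"] = np.array([a.shape[0] for a in arrays])
    return ids, output


ids, batch = sketch_collate(
    [
        ("id1", {"speech": np.array([1.0, 2.0]), "text": np.array([1, 4, 5])}),
        ("id2", {"speech": np.array([3.0]), "text": np.array([2])}),
    ],
    int_pad_value=-1,
)
print(batch["speech"])  # shorter speech array padded with 0.0
print(batch["text"])    # shorter text array padded with -1
```

Passing `not_sequence=("text",)` to the sketch would keep the padded `text` array but skip its length entry.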
Examples
>>> import numpy as np
>>> from espnet2.train.collate_fn import CommonCollateFn
>>> collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
>>> data = [
... ("id1", {"speech": np.array([1.0, 2.0]), "text": np.array([1])}),
... ("id2", {"speech": np.array([3.0]), "text": np.array([2])}),
... ]
>>> result = collate_fn(data)
>>> print(result)
(['id1', 'id2'], {'speech': tensor(...), 'speech_lengths': tensor(...), 'text': tensor(...), 'text_lengths': tensor(...)})
NOTE
This class is a thin wrapper around the common_collate_fn function: calling an instance passes the batch to common_collate_fn together with the configured pad values and not_sequence keys.
Raises: AssertionError – If the input data does not match the expected format or contains incompatible keys.
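As a rough illustration of the kind of check that could raise this error, the sketch below asserts that every sample in a batch carries the same set of keys. The helper `check_batch_keys` and its error message are hypothetical and only mirror the behaviour described here; the checks actually performed by the library may differ.

```python
def check_batch_keys(data):
    """Hypothetical helper: assert that every sample dict has the same keys."""
    samples = [sample for _, sample in data]
    assert all(set(samples[0]) == set(s) for s in samples), "dict keys mismatch"


good = [
    ("id1", {"speech": [1.0, 2.0], "text": [1]}),
    ("id2", {"speech": [3.0], "text": [2]}),
]
# The second sample lacks "text", so the keys are incompatible.
bad = [
    ("id1", {"speech": [1.0, 2.0], "text": [1]}),
    ("id2", {"speech": [3.0]}),
]

check_batch_keys(good)  # passes silently

try:
    check_batch_keys(bad)
    mismatch_detected = False
except AssertionError:
    mismatch_detected = True
```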