espnet2.train.collate_fn.HuBERTCollateFn
class espnet2.train.collate_fn.HuBERTCollateFn(float_pad_value: float | int = 0.0, int_pad_value: int = -32768, label_downsampling: int = 1, pad: bool = False, rand_crop: bool = True, crop_audio: bool = True, not_sequence: Collection[str] = (), window_size: float = 25, window_shift: float = 20, sample_rate: float = 16)
Bases: CommonCollateFn
Functor class for collating audio and label data for HuBERT training.
This class inherits from CommonCollateFn and adds HuBERT-specific processing: label downsampling, random cropping, and audio cropping. Audio and label data are processed and padded before being passed to the model.
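As a rough illustration of label downsampling, the sketch below keeps every `label_downsampling`-th label. The helper name `downsample_labels` is hypothetical, and stride-based thinning is an assumption; this page does not state how the class reduces the label sequence.

```python
def downsample_labels(labels, factor):
    """Keep every `factor`-th label (hypothetical helper, not part of ESPnet)."""
    if factor < 1:
        raise ValueError("factor must be >= 1")
    # Stride slicing: indices 0, factor, 2 * factor, ...
    return labels[::factor]


# Example: labels produced at twice the frame rate the model expects.
labels_fine = [3, 3, 7, 7, 7, 1, 1, 4]
labels_coarse = downsample_labels(labels_fine, 2)
print(labels_coarse)  # [3, 7, 7, 1]
```

With a factor of 1 (the default), the label sequence is returned unchanged.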
float_pad_value
Value used for padding float tensors (default: 0.0).
- Type: Union[float, int]
int_pad_value
Value used for padding integer tensors (default: -32768).
- Type: int
label_downsampling
Factor by which to downsample labels (default: 1).
- Type: int
pad
If True, pad audio to the maximum length in the batch (default: False).
- Type: bool
rand_crop
If True, apply random cropping to the audio and labels (default: True).
- Type: bool
crop_audio
If True, crop the audio to match the desired length (default: True).
- Type: bool
not_sequence
Keys that should not have lengths calculated (default: ()).
- Type: Collection[str]
window_size
Size of the window for audio processing in ms (default: 25).
- Type: float
window_shift
Shift of the window for audio processing in ms (default: 20).
- Type: float
sample_rate
Sample rate of the audio in kHz (default: 16).
- Type: float
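The interaction of pad, rand_crop, and crop_audio can be pictured with the following sketch. It assumes that the batch target length is the longest utterance when pad is True and the shortest otherwise, and that cropping takes a contiguous slice with a random or zero start. Both helpers (`target_length`, `crop`) are hypothetical; the matching label cropping is omitted.

```python
import random


def target_length(lengths, pad):
    """Pick the batch target length (assumption: max if padding, else min)."""
    return max(lengths) if pad else min(lengths)


def crop(waveform, num_samples, rand_crop, rng=random):
    """Return a contiguous slice of at most `num_samples` samples."""
    if len(waveform) <= num_samples:
        return list(waveform)  # nothing to crop
    # Random start offset when rand_crop is set, otherwise crop from the start.
    start = rng.randint(0, len(waveform) - num_samples) if rand_crop else 0
    return list(waveform[start:start + num_samples])


lengths = [5, 3, 8]
print(target_length(lengths, pad=False))  # 3
print(target_length(lengths, pad=True))   # 8
print(crop(list(range(8)), 3, rand_crop=False))  # [0, 1, 2]
```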
Parameters:
- float_pad_value (Union[float, int]) – Value used for padding float tensors (default: 0.0).
- int_pad_value (int) – Value used for padding integer tensors (default: -32768).
- label_downsampling (int) – Factor by which to downsample labels (default: 1).
- pad (bool) – If True, pad audio to the maximum length in the batch (default: False).
- rand_crop (bool) – If True, apply random cropping to the audio and labels (default: True).
- crop_audio (bool) – If True, crop the audio to match the desired length (default: True).
- not_sequence (Collection[str]) – Keys that should not have lengths calculated (default: ()).
- window_size (float) – Size of the window for audio processing in ms (default: 25).
- window_shift (float) – Shift of the window for audio processing in ms (default: 20).
- sample_rate (float) – Sample rate of the audio in kHz (default: 16).
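Because window_size and window_shift are given in milliseconds and sample_rate in kHz, their products are sample counts. The sketch below works through that arithmetic with the default values, assuming standard sliding-window framing. The function `frames_for` is hypothetical, and the class's exact frame formula may differ.

```python
def frames_for(num_samples, window_size_ms=25, window_shift_ms=20, sample_rate_khz=16):
    """Count full windows that fit in `num_samples` (standard framing, assumed)."""
    window = int(window_size_ms * sample_rate_khz)   # 25 ms * 16 kHz = 400 samples
    shift = int(window_shift_ms * sample_rate_khz)   # 20 ms * 16 kHz = 320 samples
    if num_samples < window:
        return 0
    return (num_samples - window) // shift + 1


# One second of 16 kHz audio: (16000 - 400) // 320 + 1
print(frames_for(16000))  # 49
```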
Returns: A tuple containing the unique identifiers of the data and a dictionary of processed tensors.
Return type: Tuple[List[str], Dict[str, torch.Tensor]]
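The shape of the return value can be sketched in plain Python. The example uses lists in place of torch.Tensor and assumes that each sequence key gets a companion `<key>_lengths` entry, which is what the not_sequence parameter suggests. `pad_batch` is a hypothetical stand-in, not the library's implementation.

```python
def pad_batch(data, float_pad_value=0.0, int_pad_value=-32768):
    """Pad every key to the longest sequence in the batch (illustrative only)."""
    uttids = [uid for uid, _ in data]
    batch = {}
    for key in data[0][1]:
        seqs = [sample[key] for _, sample in data]
        # Integer sequences use int_pad_value, everything else float_pad_value.
        is_int = all(isinstance(x, int) for s in seqs for x in s)
        pad_value = int_pad_value if is_int else float_pad_value
        max_len = max(len(s) for s in seqs)
        batch[key] = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
        batch[key + "_lengths"] = [len(s) for s in seqs]  # assumed key naming
    return uttids, batch


data = [
    ("utt1", {"speech": [0.1, 0.2, 0.3], "text": [5, 6]}),
    ("utt2", {"speech": [0.4], "text": [7]}),
]
uttids, batch = pad_batch(data)
print(uttids)         # ['utt1', 'utt2']
print(batch["text"])  # [[5, 6], [7, -32768]]
```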
Raises:
- AssertionError – If the required keys “speech” and “text” are not present in the input data.
Examples
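>>> from espnet2.train.collate_fn import HuBERTCollateFn
>>> # data: list of (utterance_id, {"speech": ..., "text": ...}) pairs
>>> collate_fn = HuBERTCollateFn(pad=True, rand_crop=False)
>>> uttids, batch = collate_fn(data)
>>> print(batch)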
NOTE
The __call__ method asserts that the data contains “speech” and “text” keys and processes the data accordingly.
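The key check described in this note can be pictured as follows. This is a sketch under the assumption that every sample in the batch is inspected; the actual method may check fewer samples. `check_required_keys` is a hypothetical helper.

```python
def check_required_keys(data, required=("speech", "text")):
    """Assert that each (uid, sample) pair carries all required keys."""
    for uid, sample in data:
        for key in required:
            assert key in sample, f"{uid}: missing key '{key}'"


good = [("utt1", {"speech": [0.0, 0.1], "text": [4]})]
bad = [("utt2", {"speech": [0.0, 0.1]})]  # no "text" entry

check_required_keys(good)  # passes silently
try:
    check_required_keys(bad)
    failed = False
except AssertionError as err:
    failed = True
    print(err)  # utt2: missing key 'text'
```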