espnet2.train.collate_fn.common_collate_fn
espnet2.train.collate_fn.common_collate_fn(data: Collection[Tuple[str, Dict[str, ndarray]]], float_pad_value: float | int = 0.0, int_pad_value: int = -32768, not_sequence: Collection[str] = ()) → Tuple[List[str], Dict[str, Tensor]]
Concatenate ndarray-list to an array and convert to torch.Tensor.
This function collates a batch of samples into a dictionary of batched tensors. Arrays stored under the same key are padded to a common length so that they can be combined into one tensor per key and passed on for further processing.
- Parameters:
- data (Collection[Tuple[str, Dict[str, np.ndarray]]]) – A collection of tuples, where each tuple contains a unique identifier and a dictionary of NumPy arrays. The dictionaries must have the same keys across all samples.
- float_pad_value (Union[float, int], optional) – The value used to pad floating-point tensors. Defaults to 0.0.
- int_pad_value (int, optional) – The value used to pad integer tensors. Defaults to -32768.
- not_sequence (Collection[str], optional) – A collection of keys that are not treated as sequences, so no lengths are calculated for them. Defaults to an empty collection.
- Returns: A tuple where the first element is a list of unique identifiers and the second element is a dictionary of tensors, where each tensor corresponds to the padded data.
- Return type: Tuple[List[str], Dict[str, torch.Tensor]]
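As an illustration of the padding behavior described above, here is a minimal NumPy-only sketch. It does not call ESPnet, and the helper name `sketch_collate` is hypothetical. It assumes that the first axis of each array is the sequence length, that integer arrays are padded with int_pad_value and all other arrays with float_pad_value, and that a `<key>_lengths` entry is added for every key not listed in not_sequence. The real function returns torch.Tensor values rather than NumPy arrays.

```python
import numpy as np


def sketch_collate(data, float_pad_value=0.0, int_pad_value=-32768, not_sequence=()):
    """Hypothetical NumPy-only illustration of padding-based collation."""
    uttids = [uid for uid, _ in data]
    samples = [sample for _, sample in data]

    output = {}
    for key in samples[0]:
        arrays = [sample[key] for sample in samples]
        # Assumption: integer arrays use int_pad_value, everything else float_pad_value.
        pad_value = int_pad_value if arrays[0].dtype.kind == "i" else float_pad_value
        # Assumption: axis 0 is the sequence length and trailing dimensions match.
        max_len = max(a.shape[0] for a in arrays)
        shape = (len(arrays), max_len) + arrays[0].shape[1:]
        padded = np.full(shape, pad_value, dtype=arrays[0].dtype)
        for i, a in enumerate(arrays):
            padded[i, : a.shape[0]] = a
        output[key] = padded
        # Assumption: lengths are recorded only for keys treated as sequences.
        if key not in not_sequence:
            output[key + "_lengths"] = np.array(
                [a.shape[0] for a in arrays], dtype=np.int64
            )
    return uttids, output


data = [
    ("utt1", {"speech": np.array([0.1, 0.2, 0.3], dtype=np.float32),
              "text": np.array([5, 6], dtype=np.int64)}),
    ("utt2", {"speech": np.array([0.4], dtype=np.float32),
              "text": np.array([7, 8, 9], dtype=np.int64)}),
]
uttids, batch = sketch_collate(data)
# batch["text"] is [[5, 6, -32768], [7, 8, 9]]
# batch["speech_lengths"] is [3, 1]
```

In this sketch the shorter "speech" array is filled with 0.0 and the shorter "text" array with -32768, which matches the documented defaults for the two pad values.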
Examples
>>> from espnet2.samplers.constant_batch_sampler import ConstantBatchSampler
>>> import espnet2.tasks.abs_task
>>> from espnet2.train.collate_fn import common_collate_fn
>>> from espnet2.train.dataset import ESPnetDataset
>>> sampler = ConstantBatchSampler(...)
>>> dataset = ESPnetDataset(...)
>>> keys = next(iter(sampler))
>>> batch = [dataset[key] for key in keys]
>>> # The return value is a tuple: (list of identifiers, dict of tensors)
>>> uttids, batch = common_collate_fn(batch)
>>> model(**batch)
Note that the dictionary keys of the batch are propagated unchanged from those of the dataset.
- Raises:
- AssertionError – If the keys of the dictionaries in the data collection do not match, or if any key ends with "_lengths".
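The two conditions listed under Raises can be sketched as plain assertions. This is a minimal illustration rather than ESPnet's own code; the helper names `check_batch_keys` and `raises_assertion` and the assertion messages are hypothetical.

```python
def check_batch_keys(data):
    """Hypothetical helper mirroring the two documented AssertionError conditions."""
    samples = [sample for _, sample in data]
    # Every sample must expose exactly the same set of keys.
    assert all(set(samples[0]) == set(s) for s in samples), "dict keys do not match"
    # Keys ending in "_lengths" are rejected.
    assert not any(
        k.endswith("_lengths") for k in samples[0]
    ), "keys must not end with '_lengths'"


def raises_assertion(data):
    """Return True if check_batch_keys rejects the batch."""
    try:
        check_batch_keys(data)
    except AssertionError:
        return True
    return False


ok = [("utt1", {"speech": [0.1]}), ("utt2", {"speech": [0.2]})]
mismatched = [("utt1", {"speech": [0.1]}), ("utt2", {"text": [1]})]
reserved = [("utt1", {"speech_lengths": [1]})]
```

Here `ok` passes the check, while `mismatched` and `reserved` each trigger an AssertionError.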