espnet2.tts.utils.duration_calculator.DurationCalculator
class espnet2.tts.utils.duration_calculator.DurationCalculator(*args, **kwargs)
Bases: Module
Duration calculator for ESPnet2.
This module implements a duration calculator that converts attention weights from a sequence-to-sequence model into durations and focus rate values. It accepts 2D attention weights from Tacotron 2 and 4D attention weights from Transformer-based models.
- Parameters: att_ws (torch.Tensor) – Attention weight tensor with shape (T_feats, T_text) for Tacotron 2 or (#layers, #heads, T_feats, T_text) for Transformer models.
- Returns: A tuple containing:
  - LongTensor: Duration of each input, with shape (T_text,).
  - Tensor: Focus rate value.
- Return type: Tuple[torch.LongTensor, torch.Tensor]
- Raises: ValueError – If att_ws is not a 2D or 4D tensor.
####### Examples
>>> import torch
>>> from espnet2.tts.utils.duration_calculator import DurationCalculator
>>> calculator = DurationCalculator()
>>> att_ws_tacotron = torch.rand(100, 50)  # Example for Tacotron 2
>>> duration, focus_rate = calculator(att_ws_tacotron)
>>> print(duration.shape)  # Should print: torch.Size([50])
>>> print(focus_rate)  # Focus rate value
>>> att_ws_transformer = torch.rand(6, 8, 100, 50)  # Example for Transformer
>>> duration, focus_rate = calculator(att_ws_transformer)
>>> print(duration.shape)  # Should print: torch.Size([50])
>>> print(focus_rate)  # Focus rate value
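The description above does not spell out how durations and the focus rate are derived from the attention weights. The following is a minimal sketch of one common approach for the 2D (T_feats, T_text) case, written with plain nested lists instead of tensors so that it runs without PyTorch. It assumes that each output frame is assigned to the input token with the largest attention weight, and that the focus rate is the mean of those per-frame maxima. The function name `durations_and_focus_rate` is hypothetical and not part of ESPnet.

```python
from typing import List, Tuple


def durations_and_focus_rate(att_ws: List[List[float]]) -> Tuple[List[int], float]:
    """Sketch for one (T_feats, T_text) attention matrix given as nested lists.

    Assumption: a frame belongs to the token with the largest weight, and the
    focus rate is the mean of the per-frame maximum weights.
    """
    t_text = len(att_ws[0])
    durations = [0] * t_text
    peak_sum = 0.0
    for frame in att_ws:
        # The token with the largest weight "owns" this output frame.
        peak = max(frame)
        durations[frame.index(peak)] += 1
        peak_sum += peak
    # Average peak weight over all output frames.
    return durations, peak_sum / len(att_ws)


# 4 output frames attending over 3 input tokens.
att_ws = [
    [0.9, 0.1, 0.0],
    [0.8, 0.2, 0.0],
    [0.1, 0.7, 0.2],
    [0.0, 0.2, 0.8],
]
durations, focus_rate = durations_and_focus_rate(att_ws)
print(durations)             # [2, 1, 1]
print(round(focus_rate, 2))  # 0.8
```

Under this assumption the durations always sum to T_feats, because every output frame is counted exactly once.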
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(att_ws: Tensor) → Tuple[Tensor, Tensor]
Convert attention weight to durations.
This method computes the duration of each input token and the focus rate from the given attention weight tensor. The number of dimensions of the input determines whether it is treated as Tacotron 2 attention (2D) or Transformer attention (4D).
- Parameters: att_ws (torch.Tensor) – Attention weight tensor. It can have one of the following shapes:
  - (T_feats, T_text) for Tacotron 2.
  - (#layers, #heads, T_feats, T_text) for Transformer models.
- Returns: A tuple containing:
  - Duration of each input, with shape (T_text,), as a LongTensor.
  - Focus rate value as a Tensor.
- Return type: Tuple[torch.LongTensor, torch.Tensor]
- Raises: ValueError – If att_ws is not a 2D or 4D tensor.
####### Examples
>>> import torch
>>> from espnet2.tts.utils.duration_calculator import DurationCalculator
>>> duration_calculator = DurationCalculator()
>>> att_ws_tacotron = torch.rand(100, 50)  # Example for Tacotron 2
>>> duration, focus_rate = duration_calculator(att_ws_tacotron)
>>> print(duration.shape)  # Should print: torch.Size([50])
>>> print(focus_rate)  # Focus rate value
>>> att_ws_transformer = torch.rand(6, 8, 100, 50)  # Example for Transformer
>>> duration, focus_rate = duration_calculator(att_ws_transformer)
>>> print(duration.shape)  # Should print: torch.Size([50])
>>> print(focus_rate)  # Focus rate value
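For the 4D case, a single (T_feats, T_text) matrix has to be chosen from the (#layers, #heads) stack before durations can be counted, and any other dimensionality is rejected with a ValueError. The sketch below illustrates one plausible selection rule: keep the head with the highest focus rate. This rule is an assumption, as are the hypothetical helper names `focus_rate_2d`, `ndim`, and `select_attention`. Nested lists again stand in for tensors.

```python
from typing import List

Matrix = List[List[float]]


def focus_rate_2d(att_w: Matrix) -> float:
    """Mean of the per-frame maximum attention weight for one head."""
    return sum(max(frame) for frame in att_w) / len(att_w)


def ndim(x) -> int:
    """Nesting depth of a list of lists, standing in for tensor.dim()."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0]
    return depth


def select_attention(att_ws) -> Matrix:
    """Return one (T_feats, T_text) matrix from 2D or 4D input."""
    dims = ndim(att_ws)
    if dims == 2:
        # Tacotron 2 style input is already a single matrix.
        return att_ws
    if dims == 4:
        # Flatten (#layers, #heads) and keep the head with the highest
        # focus rate (assumed selection rule).
        heads = [head for layer in att_ws for head in layer]
        return max(heads, key=focus_rate_2d)
    raise ValueError("att_ws should be a 2D or 4D tensor.")


sharp = [[1.0, 0.0], [0.0, 1.0]]    # focus rate 1.0
diffuse = [[0.5, 0.5], [0.5, 0.5]]  # focus rate 0.5
att_ws_4d = [[diffuse, sharp]]      # 1 layer, 2 heads
chosen = select_attention(att_ws_4d)
print(chosen == sharp)  # True
```

Choosing the sharpest head favors attention maps that are close to a monotonic alignment, which is what a per-token duration count needs.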