espnet2.gan_tts.jets.length_regulator.GaussianUpsampling
class espnet2.gan_tts.jets.length_regulator.GaussianUpsampling(delta=0.1)
Bases: Module
Gaussian upsampling with fixed temperature as described in:
https://arxiv.org/abs/2010.04301
This module expands the hidden states based on the specified durations, using a Gaussian function to compute the attention weights for the upsampling process.
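The weighting can be written out as follows. This is a sketch that assumes each token's Gaussian is centered on the middle of its duration span and that `delta` multiplies the squared distance, so it acts like an inverse variance. The symbols $d_n$ (duration of token $n$), $c_n$ (its center), $t$ (output frame index), $\mathbf{h}_n$ (row $n$ of `hs`), $\mathbf{y}_t$ (row $t$ of the output) and $N = T_{text}$ are notation introduced here:

$$
c_n = \sum_{k=1}^{n} d_k - \frac{d_n}{2}, \qquad
w_{t,n} = \frac{\exp\left(-\delta\,(t - c_n)^2\right)}{\sum_{m=1}^{N} \exp\left(-\delta\,(t - c_m)^2\right)}, \qquad
\mathbf{y}_t = \sum_{n=1}^{N} w_{t,n}\,\mathbf{h}_n
$$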
delta
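The temperature parameter that controls the spread of the Gaussian function. It scales the squared distance between an output frame and each token center, so larger values give narrower Gaussians and a sharper upsampling that is closer to hard repetition.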
Type: float
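To make the role of `delta` concrete, here is a minimal unmasked sketch in plain PyTorch. It is not the ESPnet implementation: the helper names `gaussian_weights` and `gaussian_upsample` are hypothetical, and it assumes the output length is the longest total duration in the batch. It prints the mean peak weight per frame for several `delta` values, which grows as the Gaussians narrow.

```python
import torch


def gaussian_weights(ds: torch.Tensor, delta: float = 0.1) -> torch.Tensor:
    """Hypothetical helper: Gaussian attention weights (B, T_feats, T_text), no masks."""
    ds = ds.float()
    # Assumption: output length is the longest total duration in the batch.
    T_feats = int(ds.sum(dim=-1).max().item())
    t = torch.arange(T_feats, dtype=torch.float32).view(1, -1, 1)  # (1, T_feats, 1)
    # Center of each token's span: its end position minus half its duration.
    c = (ds.cumsum(dim=-1) - ds / 2).unsqueeze(1)  # (B, 1, T_text)
    energy = -delta * (t - c) ** 2  # larger delta -> narrower Gaussian
    return torch.softmax(energy, dim=-1)  # each frame's weights sum to 1


def gaussian_upsample(hs: torch.Tensor, ds: torch.Tensor, delta: float = 0.1) -> torch.Tensor:
    """Hypothetical sketch: expand hs (B, T_text, adim) to (B, T_feats, adim)."""
    return torch.matmul(gaussian_weights(ds, delta), hs)


if __name__ == "__main__":
    ds = torch.tensor([[1, 2, 2]])
    hs = torch.randn(1, 3, 4)
    print(gaussian_upsample(hs, ds).shape)  # torch.Size([1, 5, 4])
    # Mean peak weight per frame grows with delta (sharper alignment).
    for delta in (0.1, 1.0, 10.0):
        print(delta, gaussian_weights(ds, delta).max(dim=-1).values.mean().item())
```

Because every output frame is a convex combination of the token states, a soft alignment never produces values outside the range spanned by `hs`; `delta` only decides how concentrated each combination is.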
Parameters:
- hs (Tensor) – Batched hidden state to be expanded (B, T_text, adim).
- ds (Tensor) – Batched token duration (B, T_text).
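- h_masks (Tensor, optional) – Mask tensor for output frames (B, T_feats). If given, its last dimension sets the output length T_feats.
- d_masks (Tensor, optional) – Mask tensor for input tokens (B, T_text). Masked (padded) tokens receive zero attention weight.
Returns: Expanded hidden state (B, T_feats, adim).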
Return type: Tensor
Warns: If the predicted durations include an all-zero sequence, a warning is logged (no exception is raised) and the first element is filled with 1.
Examples
>>> import torch
>>> from espnet2.gan_tts.jets.length_regulator import GaussianUpsampling
>>> upsampler = GaussianUpsampling(delta=0.1)
>>> hs = torch.randn(1, 3, 256)  # (B, T_text, adim)
>>> ds = torch.tensor([[1, 2, 2]])  # (B, T_text), frames per token
>>> expanded_hs = upsampler(hs, ds)
>>> print(expanded_hs.shape)  # T_feats = 1 + 2 + 2
torch.Size([1, 5, 256])
NOTE
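If h_masks is given, the output length T_feats is h_masks.size(-1); otherwise it is derived from the summed durations in ds. If d_masks is given, padded tokens are excluded from the Gaussian attention.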
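The effect of the two masks can be sketched in plain PyTorch. This is an illustration, not the library code: `masked_gaussian_weights` is a hypothetical name, masks are assumed to be boolean with True marking valid entries, and in this sketch h_masks only determines T_feats (padded frames still receive weights and are assumed to be ignored downstream).

```python
import torch


def masked_gaussian_weights(
    ds: torch.Tensor,
    h_masks: torch.Tensor,
    d_masks: torch.Tensor,
    delta: float = 0.1,
) -> torch.Tensor:
    """Hypothetical sketch: Gaussian weights (B, T_feats, T_text) with masks.

    h_masks: (B, T_feats) bool, True for valid output frames.
    d_masks: (B, T_text) bool, True for valid (non-padded) tokens.
    """
    ds = ds.float()
    T_feats = h_masks.size(-1)  # output length comes from the frame mask
    t = torch.arange(T_feats, dtype=torch.float32).view(1, -1, 1)  # (1, T_feats, 1)
    c = (ds.cumsum(dim=-1) - ds / 2).unsqueeze(1)  # token centers (B, 1, T_text)
    energy = -delta * (t - c) ** 2  # (B, T_feats, T_text)
    # Padded tokens get -inf energy, so softmax gives them exactly zero weight.
    energy = energy.masked_fill(~d_masks.unsqueeze(1), float("-inf"))
    return torch.softmax(energy, dim=-1)


if __name__ == "__main__":
    ds = torch.tensor([[1, 2, 2], [2, 1, 0]])
    h_masks = torch.tensor([[True] * 5, [True] * 3 + [False] * 2])
    d_masks = torch.tensor([[True, True, True], [True, True, False]])
    w = masked_gaussian_weights(ds, h_masks, d_masks)
    print(w.shape)  # torch.Size([2, 5, 3])
    print(w[1, :, 2])  # all zeros: the padded token is never attended to
```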
forward(hs, ds, h_masks=None, d_masks=None)
Upsample hidden states according to token durations with fixed-temperature Gaussian upsampling, as in:
https://arxiv.org/abs/2010.04301
Parameters:
- hs (Tensor) – Batched hidden state to be expanded (B, T_text, adim).
- ds (Tensor) – Batched token duration (B, T_text).
- h_masks (Tensor, optional) – Mask tensor for output frames (B, T_feats). Default is None.
- d_masks (Tensor, optional) – Mask tensor for input tokens (B, T_text). Default is None.
Returns: Expanded hidden state (B, T_feats, adim).
Return type: Tensor
Warns: If the predicted durations include an all-zero sequence, a warning is logged; no exception is raised.
Examples
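>>> import torch
>>> from espnet2.gan_tts.jets.length_regulator import GaussianUpsampling
>>> model = GaussianUpsampling(delta=0.1)
>>> hs = torch.randn(2, 5, 10)  # (B, T_text, adim)
>>> ds = torch.tensor([[1, 2, 0, 0, 0], [1, 0, 0, 0, 0]])  # (B, T_text)
>>> h_masks = torch.tensor([[True, True, True], [True, False, False]])  # (B, T_feats)
>>> d_masks = torch.tensor([[True, True, False, False, False], [True, False, False, False, False]])
>>> output = model(hs, ds, h_masks=h_masks, d_masks=d_masks)
>>> print(output.shape)  # T_feats is taken from h_masks
torch.Size([2, 3, 10])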
NOTE
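The method handles an all-zero duration sequence by filling its first element with 1. This should not occur with teacher forcing, where ground-truth durations are used; it can occur at inference time, when the durations come from a duration predictor that outputs only zeros.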
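The fallback described above can be illustrated with a small standalone helper. This is a sketch of the documented behavior, not the library's code: `fill_all_zero_durations` is a hypothetical name, it assumes the check is applied per sequence, and it works on a copy of the durations. Without the fix, an all-zero sequence would expand to zero output frames.

```python
import logging

import torch


def fill_all_zero_durations(ds: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper following the documented fallback for all-zero durations."""
    ds = ds.clone()  # keep the caller's tensor unchanged
    zero_rows = ds.sum(dim=-1).eq(0)  # (B,) True where every duration is 0
    if zero_rows.any():
        logging.warning("durations include an all-zero sequence; filling the first element with 1")
        ds[zero_rows, 0] = 1  # each such sequence now yields at least one frame
    return ds


if __name__ == "__main__":
    ds = torch.tensor([[0, 0, 0], [1, 2, 0]])
    print(fill_all_zero_durations(ds).tolist())  # [[1, 0, 0], [1, 2, 0]]
```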