espnet2.gan_tts.parallel_wavegan.upsample.UpsampleNetwork
class espnet2.gan_tts.parallel_wavegan.upsample.UpsampleNetwork(upsample_scales: List[int], nonlinear_activation: str | None = None, nonlinear_activation_params: Dict[str, Any] = {}, interpolate_mode: str = 'nearest', freq_axis_kernel_size: int = 1)
Bases: Module
Upsampling network module.
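This module upsamples the input tensor along the time axis. For each entry in upsample_scales, it applies an interpolation layer that stretches the time axis by that factor, followed by a convolution layer and, if nonlinear_activation is set, the named activation function. The number of channels is left unchanged.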
- Parameters:
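- upsample_scales (List[int]) – List of upsampling scales, applied in order along the time axis. The total upsampling factor is their product.
- nonlinear_activation (Optional[str]) – Name of the activation function class in torch.nn, e.g. 'ReLU'. If None, no activation is applied.
- nonlinear_activation_params (Dict[str, Any]) – Keyword arguments for the specified activation function.
- interpolate_mode (str) – Interpolation mode used for upsampling, e.g. 'nearest'.
- freq_axis_kernel_size (int) – Kernel size of the convolution along the frequency axis. Must be odd.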
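To make the role of upsample_scales concrete, here is a minimal pure-Python sketch of nearest-neighbour stretching along the time axis. It is an illustration only, not the ESPnet implementation: the helper names `stretch_nearest` and `upsample_time_axis` are hypothetical, and the real module also applies a convolution (and optionally an activation) after each stretch, so its output values are not plain repeats.

```python
from typing import List, Sequence


def stretch_nearest(frames: Sequence[float], scale: int) -> List[float]:
    """Repeat every frame `scale` times (nearest-neighbour upsampling in time)."""
    return [value for value in frames for _ in range(scale)]


def upsample_time_axis(frames: Sequence[float], upsample_scales: Sequence[int]) -> List[float]:
    """Apply each scale in order; the length grows by the product of the scales."""
    out = list(frames)
    for scale in upsample_scales:
        out = stretch_nearest(out, scale)
    return out


frames = [0.1, 0.2, 0.3]  # three feature frames of a single channel
stretched = upsample_time_axis(frames, [2, 2])
print(len(stretched))  # 12, i.e. 3 * 2 * 2
print(stretched[:4])   # [0.1, 0.1, 0.1, 0.1]
```

Applying the scales one after another rather than in a single step gives each stage its own convolution in the actual module, which is why the parameter is a list and not a single factor.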
###### Examples
>>> import torch
>>> from espnet2.gan_tts.parallel_wavegan.upsample import UpsampleNetwork
>>> upsample_network = UpsampleNetwork(
...     upsample_scales=[2, 2],
...     nonlinear_activation='ReLU',
...     nonlinear_activation_params={},
...     interpolate_mode='nearest',
...     freq_axis_kernel_size=3,
... )
>>> input_tensor = torch.randn(1, 80, 50)  # (B, C, T_feats)
>>> output_tensor = upsample_network(input_tensor)
>>> output_tensor.shape  # (B, C, T_wav), with T_wav = 50 * 2 * 2
torch.Size([1, 80, 200])
forward(c: Tensor) → Tensor
Calculate forward propagation.
This method processes the input tensor through a series of upsampling layers and returns the upsampled tensor.
- Parameters: c (Tensor) – Input tensor of shape (B, C, T_feats), where:
- B: Batch size
- C: Number of feature channels
- T_feats: Number of feature frames (length of the time axis)
- Returns: Upsampled tensor of shape (B, C, T_wav), where T_wav = T_feats * prod(upsample_scales) is the length of the time axis after upsampling.
- Return type: Tensor
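The output length can be checked before running the model. The snippet below is a small sketch of the relation T_wav = T_feats * prod(upsample_scales); the helper name `expected_output_length` is hypothetical and not part of ESPnet, and the scales [4, 4, 4, 4] are only an assumed configuration chosen so that the product is 256.

```python
import math
from typing import Sequence


def expected_output_length(t_feats: int, upsample_scales: Sequence[int]) -> int:
    """Return T_wav = T_feats * prod(upsample_scales)."""
    return t_feats * math.prod(upsample_scales)


# Matches the shapes used in the examples on this page.
print(expected_output_length(50, [2, 2]))  # 200
print(expected_output_length(5, [2, 2]))   # 20

# Assumed configuration: four stages of 4x give 256 output steps per input frame.
print(expected_output_length(100, [4, 4, 4, 4]))  # 25600
```

When the upsampled features are meant to align with waveform samples, the product of the scales would typically be chosen to equal the hop size used for feature extraction; that is an assumption about the surrounding pipeline, not something this class enforces.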
###### Examples
>>> import torch
>>> from espnet2.gan_tts.parallel_wavegan.upsample import UpsampleNetwork
>>> model = UpsampleNetwork(upsample_scales=[2, 2])
>>> input_tensor = torch.randn(1, 10, 5)  # (B, C, T_feats)
>>> output_tensor = model(input_tensor)
>>> output_tensor.shape  # T_wav = 5 * 2 * 2
torch.Size([1, 10, 20])