espnet2.gan_tts.parallel_wavegan.upsample.Stretch2d
class espnet2.gan_tts.parallel_wavegan.upsample.Stretch2d(x_scale: int, y_scale: int, mode: str = 'nearest')
Bases: Module
Stretch2d module.
This module stretches (upsamples) a 4D input tensor along its last two axes by integer scaling factors. It is typically applied to spectrogram-like features, where those two axes are frequency and time.
The code is modified from https://github.com/kan-bayashi/ParallelWaveGAN.
x_scale
X scaling factor (Time axis in spectrogram).
- Type: int
y_scale
Y scaling factor (Frequency axis in spectrogram).
- Type: int
mode
Interpolation mode for upsampling.
- Type: str
Parameters:
- x_scale (int) – X scaling factor (Time axis in spectrogram).
- y_scale (int) – Y scaling factor (Frequency axis in spectrogram).
- mode (str) – Interpolation mode. Defaults to 'nearest'.
Returns: None
####### Examples
>>> import torch
>>> from espnet2.gan_tts.parallel_wavegan.upsample import Stretch2d
>>> stretch = Stretch2d(x_scale=2, y_scale=3, mode='nearest')
>>> input_tensor = torch.randn(1, 1, 4, 4) # (B, C, F, T)
>>> output_tensor = stretch(input_tensor)
>>> output_tensor.shape
torch.Size([1, 1, 12, 8]) # (B, C, F * y_scale, T * x_scale)
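To make the effect of mode='nearest' concrete, the following pure-Python sketch stretches a tiny 2 x 2 grid. The helper `stretch2d_nearest` is hypothetical and is not part of ESPnet. It assumes integer scaling factors, for which nearest-neighbor upsampling amounts to repeating each element.

```python
def stretch2d_nearest(grid, x_scale, y_scale):
    """Hypothetical helper: repeat each column x_scale times and each row y_scale times."""
    out = []
    for row in grid:
        # Stretch along the X (time) axis by repeating every value.
        stretched_row = [value for value in row for _ in range(x_scale)]
        # Stretch along the Y (frequency) axis by repeating the whole row.
        out.extend(list(stretched_row) for _ in range(y_scale))
    return out


grid = [[1, 2],
        [3, 4]]  # 2 frequency bins x 2 time steps
stretched = stretch2d_nearest(grid, x_scale=2, y_scale=3)
for row in stretched:
    print(row)
```

The result has 2 * 3 = 6 rows and 2 * 2 = 4 columns, matching the (F * y_scale, T * x_scale) shape rule described above.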
Initialize Stretch2d module.
- Parameters:
- x_scale (int) – X scaling factor (Time axis in spectrogram).
- y_scale (int) – Y scaling factor (Frequency axis in spectrogram).
- mode (str) – Interpolation mode. Defaults to 'nearest'.
forward(x: Tensor) → Tensor
Calculate forward propagation.
This method performs the forward pass of the Stretch2d module by applying upsampling to the input tensor using the specified scaling factors and interpolation mode.
- Parameters: x (Tensor) – Input tensor of shape (B, C, F, T), where B is the batch size, C the number of channels, F the number of frequency bins, and T the number of time steps.
- Returns: Interpolated tensor of shape (B, C, F * y_scale, T * x_scale), where y_scale and x_scale are the scaling factors defined during initialization.
- Return type: Tensor
####### Examples
>>> import torch
>>> # A 4D input needs a 2D interpolation mode such as 'nearest' or 'bilinear'.
>>> stretch = Stretch2d(x_scale=2, y_scale=3, mode='bilinear')
>>> input_tensor = torch.randn(1, 1, 4, 4) # (B, C, F, T)
>>> output_tensor = stretch(input_tensor)
>>> output_tensor.shape
torch.Size([1, 1, 12, 8]) # Output shape after upsampling
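The axis order can be easier to see with a non-square input. The sketch below uses a hypothetical stand-in class, `Stretch2dSketch`, which assumes that the forward pass simply delegates to `torch.nn.functional.interpolate` with `scale_factor=(y_scale, x_scale)`. It is an illustration under that assumption, not the ESPnet implementation itself.

```python
import torch
import torch.nn.functional as F


class Stretch2dSketch(torch.nn.Module):
    """Hypothetical stand-in for illustration; not the ESPnet class."""

    def __init__(self, x_scale: int, y_scale: int, mode: str = "nearest"):
        super().__init__()
        self.x_scale = x_scale
        self.y_scale = y_scale
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # For a (B, C, F, T) input, scale_factor is ordered (F, T),
        # so y_scale comes first and x_scale second.
        return F.interpolate(
            x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode
        )


x = torch.randn(1, 1, 4, 5)  # (B, C, F, T) with F=4, T=5

y_nearest = Stretch2dSketch(x_scale=2, y_scale=3, mode="nearest")(x)
y_bilinear = Stretch2dSketch(x_scale=2, y_scale=3, mode="bilinear")(x)

print(y_nearest.shape)   # frequency axis scaled by 3, time axis scaled by 2
print(y_bilinear.shape)  # same shape; only the interpolated values differ
```

Both modes produce a tensor of shape (1, 1, 12, 10). The choice of mode changes only how the new values are filled in.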