espnet2.gan_codec.shared.discriminator.stft_discriminator.ComplexConv2d
class espnet2.gan_codec.shared.discriminator.stft_discriminator.ComplexConv2d(dim, dim_out, kernel_size, stride=1, padding=0)
Bases: Module
ComplexConv2d module that performs complex-valued 2D convolution.
This module extends the standard 2D convolution to handle complex-valued inputs. The weights and biases are stored in real-valued format and are converted to complex format during the forward pass.
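To illustrate the storage scheme described above, here is a minimal pure-Python sketch. It is not the module's actual code. It keeps each complex weight as a real-valued (re, im) pair, rebuilds the complex values at call time, and applies a single-channel 2D cross-correlation with stride 1 and no padding. The helper names `to_complex` and `complex_conv2d_valid` are hypothetical and exist only for this illustration.

```python
# Illustrative only: hypothetical helpers, not part of ESPnet or PyTorch.

def to_complex(pairs):
    """Rebuild a 2D grid of complex numbers from real-valued (re, im) pairs."""
    return [[complex(re, im) for re, im in row] for row in pairs]


def complex_conv2d_valid(x, weight_pairs, bias_pair):
    """Single-channel cross-correlation with stride 1 and no padding.

    x:            2D list of complex inputs
    weight_pairs: 2D list of (re, im) tuples, the kernel in real-valued storage
    bias_pair:    one (re, im) tuple, the bias in real-valued storage
    """
    w = to_complex(weight_pairs)  # real-valued storage converted at call time
    b = complex(*bias_pair)
    kh, kw = len(w), len(w[0])
    out_h = len(x) - kh + 1
    out_w = len(x[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            acc = b
            for di in range(kh):
                for dj in range(kw):
                    # Complex multiply-accumulate over the kernel window
                    acc += x[i + di][j + dj] * w[di][dj]
            row.append(acc)
        out.append(row)
    return out


x = [[1 + 1j, 2 + 0j],
     [0 + 1j, 1 - 1j]]
weight_pairs = [[(0.0, 1.0)]]  # 1x1 kernel equal to the imaginary unit
bias_pair = (1.0, 0.0)
y = complex_conv2d_valid(x, weight_pairs, bias_pair)
```

With a 1x1 kernel equal to the imaginary unit, each output is the input multiplied by `1j` plus the bias, which makes the result easy to check by hand.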
weight
The complex convolution weights as real-valued tensor.
- Type: torch.Tensor
bias
The complex convolution biases as real-valued tensor.
- Type: torch.Tensor
stride
The stride of the convolution.
- Type: int
padding
The padding applied to the input.
- Type: int
Parameters:
- dim (int) – Number of input channels.
- dim_out (int) – Number of output channels.
- kernel_size (int or tuple) – Size of the convolution kernel.
- stride (int or tuple, optional) – Stride of the convolution. Default is 1.
- padding (int or tuple, optional) – Padding added to both sides of the input. Default is 0.
Returns: The result of the complex 2D convolution.
Return type: Tensor
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.stft_discriminator import ComplexConv2d
>>> conv = ComplexConv2d(dim=2, dim_out=3, kernel_size=3)
>>> input_tensor = torch.randn(1, 2, 10, 10, dtype=torch.complex64)
>>> output = conv(input_tensor)
>>> print(output.shape)
torch.Size([1, 3, 8, 8]) # Output shape depends on padding and stride
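The 8 x 8 spatial size in the example is consistent with the usual convolution output-size arithmetic. The sketch below assumes the module follows that standard rule with dilation 1, which the source does not state explicitly. `conv_out_size` is a hypothetical helper, not an ESPnet function.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Output length along one spatial axis, assuming dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# Settings from the example above: 10 x 10 input, 3 x 3 kernel, defaults otherwise
h_out = conv_out_size(10, kernel_size=3)

# The same input with padding=1 keeps the spatial size unchanged
h_same = conv_out_size(10, kernel_size=3, padding=1)

# Adding stride=2 roughly halves it
h_strided = conv_out_size(10, kernel_size=3, stride=2, padding=1)
```

Apply the same calculation independently to the height and width axes when the kernel, stride, or padding is given as a tuple.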
forward(x)
Calculate forward propagation.
This method converts the stored real-valued weight and bias to complex form and applies a 2D convolution to the input, using the configured stride and padding.
- Parameters: x (Tensor) – Complex-valued input with shape (B, dim, H, W), where B is the batch size and dim is the number of input channels.
- Returns: The result of the complex 2D convolution.
- Return type: Tensor
Reference:
- Paper: https://arxiv.org/pdf/2107.03312.pdf
- Implementation: https://github.com/alibaba-damo-academy/FunCodec.git
####### Examples
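>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.stft_discriminator import ComplexSTFTDiscriminator
>>> # ComplexSTFTDiscriminator applies complex convolutional layers to the STFT of the input
>>> discriminator = ComplexSTFTDiscriminator()
>>> input_signal = torch.randn(4, 1, 1024) # Batch of 4 signals
>>> output = discriminator(input_signal)
>>> print(len(output)) # Length of the outer list
>>> print(output[0][0].shape) # Shape of the first output tensor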