espnet2.gan_codec.shared.discriminator.stft_discriminator.ModReLU
class espnet2.gan_codec.shared.discriminator.stft_discriminator.ModReLU
Bases: Module
Complex ReLU module.
This module applies a modified ReLU (modReLU) activation function to complex-valued inputs. It adds a learnable bias to the absolute value of the input, applies ReLU to the result, and then reconstructs the complex output using the original phase of the input.
Reference: https://arxiv.org/abs/1705.09792
b
Learnable parameter that is added to the absolute value of the input before applying the ReLU function.
Type: torch.nn.Parameter
Parameters: None
Returns: The complex output after applying the modified ReLU.
Return type: Tensor
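The computation described above can be sketched as a standalone function. This is a minimal illustration, not the ESPnet implementation: `mod_relu_sketch` is a hypothetical name, and the bias is passed as a plain float rather than held as the learnable parameter `b`.

```python
import torch
import torch.nn.functional as F


def mod_relu_sketch(x: torch.Tensor, b: float = 0.0) -> torch.Tensor:
    """Illustrative modReLU: relu(|x| + b), with the phase of x kept."""
    magnitude = F.relu(torch.abs(x) + b)  # real-valued and non-negative
    phase = torch.angle(x)                # real-valued, in radians
    return torch.polar(magnitude, phase)  # magnitude * exp(1j * phase)


x = torch.tensor([1 + 2j, -3 - 4j, 0 + 0j])

# With a zero bias the magnitudes are unchanged, so the output stays close to x.
y_zero_bias = mod_relu_sketch(x, b=0.0)

# With a negative bias, elements whose magnitude is below |b| are zeroed.
y_neg_bias = mod_relu_sketch(x, b=-3.0)
```

A negative bias acts as a magnitude threshold: small-magnitude elements are suppressed, while larger ones shrink but keep their direction in the complex plane.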
####### Examples
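>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.stft_discriminator import ModReLU
>>> mod_relu = ModReLU()
>>> input_tensor = torch.tensor([1+2j, -3-4j, 0+0j])
>>> output_tensor = mod_relu(input_tensor)
>>> # Each element keeps its phase; its magnitude becomes relu(|x| + b).
>>> print(output_tensor)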
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x)
Calculate forward propagation.
This method passes an input signal through the layers of the ComplexSTFTDiscriminator. The signal undergoes a Short-Time Fourier Transform (STFT) and is then passed through a series of complex convolutional layers. Depending on the configuration, the output is either the absolute value of the logits or the real part of the complex output.
- Parameters:
- x (Tensor) – Input signal of shape (B, 1, T), where B is the batch size, 1 is the number of input channels, and T is the length of the signal.
- Returns: A nested list containing the discriminator output after processing through the network. The output is in the form of complex tensors.
- Return type: List[List[Tensor]]
Reference: Paper: https://arxiv.org/pdf/2107.03312.pdf; Implementation: https://github.com/alibaba-damo-academy/FunCodec.git
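The first and last steps of the pipeline described above can be sketched with plain PyTorch calls. This is only an illustration under stated assumptions: `n_fft` and `hop_length` are hypothetical values rather than the ESPnet defaults, and the complex convolutional layers are omitted.

```python
import torch

# Hypothetical STFT settings, chosen for illustration only.
n_fft, hop_length = 256, 64

x = torch.randn(8, 1, 1024)  # input signal of shape (B, 1, T)

# torch.stft expects (B, T), so drop the singleton channel dimension.
spec = torch.stft(
    x.squeeze(1),
    n_fft=n_fft,
    hop_length=hop_length,
    window=torch.hann_window(n_fft),
    return_complex=True,
)  # complex tensor of shape (B, n_fft // 2 + 1, frames)

# The complex convolutional layers would be applied here (omitted).
logits = spec.unsqueeze(1)  # restore a channel dimension: (B, 1, F, frames)

# Two ways to reduce complex logits to real values, as the text describes.
as_magnitude = logits.abs()
as_real = logits.real
```

Taking the absolute value discards phase, while taking the real part keeps sign information; which of the two applies depends on how the discriminator is configured.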
####### Examples
>>> import torch
>>> from espnet2.gan_codec.shared.discriminator.stft_discriminator import ComplexSTFTDiscriminator
>>> model = ComplexSTFTDiscriminator()
>>> input_signal = torch.randn(8, 1, 1024)  # Batch of 8 signals
>>> output = model(input_signal)
>>> print(len(output))  # Should print 1
>>> print(output[0][0].shape)  # Shape of the output tensor
NOTE
Ensure that the input tensor x has the expected shape, (B, 1, T), and holds the expected type of data (float or complex) so that the model can process it.