espnet2.gan_codec.funcodec.funcodec.FunCodecGenerator
class espnet2.gan_codec.funcodec.funcodec.FunCodecGenerator(sample_rate: int = 24000, hidden_dim: int = 128, codebook_dim: int = 8, encdec_channels: int = 1, encdec_n_filters: int = 32, encdec_n_residual_layers: int = 1, encdec_ratios: List[Tuple[int, int]] = [(4, 1), (4, 1), (4, 2), (4, 1)], encdec_activation: str = 'ELU', encdec_activation_params: Dict[str, Any] = {'alpha': 1.0}, encdec_norm: str = 'weight_norm', encdec_norm_params: Dict[str, Any] = {}, encdec_kernel_size: int = 7, encdec_residual_kernel_size: int = 7, encdec_last_kernel_size: int = 7, encdec_dilation_base: int = 2, encdec_causal: bool = False, encdec_pad_mode: str = 'reflect', encdec_true_skip: bool = False, encdec_compress: int = 2, encdec_lstm: int = 2, decoder_trim_right_ratio: float = 1.0, decoder_final_activation: str | None = None, decoder_final_activation_params: dict | None = None, quantizer_n_q: int = 8, quantizer_bins: int = 1024, quantizer_decay: float = 0.99, quantizer_kmeans_init: bool = True, quantizer_kmeans_iters: int = 50, quantizer_threshold_ema_dead_code: int = 2, quantizer_target_bandwidth: List[float] = [7.5, 15], quantizer_dropout: bool = True, codec_domain: List = ('time', 'time'), domain_conf: Dict | None = {}, audio_normalize: bool = False)
Bases: Module
FunCodec generator module.
This module implements the generator of the FunCodec architecture, which encodes audio into discrete codes and decodes those codes back into a waveform. It combines a SEANet-style convolutional encoder and decoder (with optional LSTM layers), a residual vector quantizer, and optional time/frequency domain transforms.
codec_domain
List indicating the codec domains for encoding and decoding (e.g., “time”, “stft”).
- Type: List
domain_conf
Configuration parameters for domain transformations.
- Type: Optional[Dict]
encoder
The encoder network used to process audio.
- Type: SEANetEncoder2d
quantizer
The quantization module to map encoder outputs to discrete codes.
decoder
The decoder network used to reconstruct audio from quantized codes.
- Type: SEANetDecoder2d
sample_rate
The sampling rate of the audio signals.
- Type: int
frame_rate
Frame rate of the encoded (code) sequence, in frames per second.
- Type: int
audio_normalize
Flag indicating whether to normalize audio inputs.
- Type: bool
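The quantizer attribute maps encoder outputs to discrete codes using several codebooks (quantizer_n_q of them, each with quantizer_bins entries). The plain-Python sketch below illustrates the general residual vector quantization idea: each stage quantizes the error left by the previous stages. It is not ESPnet's implementation; the toy codebooks and the helper names `nearest` and `rvq_encode` are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def nearest(codebook: List[Vector], v: Vector) -> int:
    """Return the index of the codeword closest to v (squared Euclidean distance)."""
    best_idx, best_dist = 0, math.inf
    for idx, code in enumerate(codebook):
        dist = sum((a - b) ** 2 for a, b in zip(v, code))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx


def rvq_encode(codebooks: List[List[Vector]], v: Vector) -> Tuple[List[int], List[float]]:
    """Hypothetical residual VQ: one index per stage, plus the final residual error."""
    residual = list(v)
    indices = []
    for codebook in codebooks:
        idx = nearest(codebook, residual)
        indices.append(idx)
        # Subtract the chosen codeword so the next stage only sees the remaining error.
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return indices, residual


# Two toy stages with 2-dimensional codewords (real codebooks are learned).
codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]],
    [[0.0, 0.0], [0.25, -0.25], [-0.25, 0.25]],
]
indices, residual = rvq_encode(codebooks, [1.2, 0.8])
print(indices, residual)
```

Adding more stages lets the quantizer represent the input more precisely, which is why the number of active codebooks trades bitrate against reconstruction quality.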
Parameters:
- sample_rate (int) – Sampling rate of the audio. Default is 24000.
- hidden_dim (int) – Dimensionality of the hidden layers. Default is 128.
- codebook_dim (int) – Dimensionality of the codebook for quantization. Default is 8.
- encdec_channels (int) – Number of input/output channels for the encoder and decoder. Default is 1.
- encdec_n_filters (int) – Number of filters in the encoder/decoder layers. Default is 32.
- encdec_n_residual_layers (int) – Number of residual layers in the encoder/decoder. Default is 1.
- encdec_ratios (List[Tuple[int, int]]) – Ratios for downsampling/upsampling in the encoder/decoder. Default is [(4, 1), (4, 1), (4, 2), (4, 1)].
- encdec_activation (str) – Activation function to use. Default is “ELU”.
- encdec_activation_params (Dict[str, Any]) – Parameters for the activation function. Default is {“alpha”: 1.0}.
- encdec_norm (str) – Normalization type to use in the layers. Default is “weight_norm”.
- encdec_norm_params (Dict[str, Any]) – Parameters for the normalization layers. Default is {}.
- encdec_kernel_size (int) – Kernel size for the convolutional layers. Default is 7.
- encdec_residual_kernel_size (int) – Kernel size for the residual layers. Default is 7.
- encdec_last_kernel_size (int) – Kernel size for the last layer. Default is 7.
- encdec_dilation_base (int) – Dilation base for the convolutional layers. Default is 2.
- encdec_causal (bool) – Flag indicating whether to use causal convolutions. Default is False.
- encdec_pad_mode (str) – Padding mode for the convolutional layers. Default is “reflect”.
- encdec_true_skip (bool) – Flag indicating whether to use true skip connections. Default is False.
- encdec_compress (int) – Compression factor for the encoder. Default is 2.
- encdec_lstm (int) – Number of LSTM layers in the encoder/decoder. Default is 2.
- decoder_trim_right_ratio (float) – Ratio to trim the right side of the decoder output. Default is 1.0.
- decoder_final_activation (Optional[str]) – Final activation function for the decoder. Default is None.
- decoder_final_activation_params (Optional[dict]) – Parameters for the final activation function. Default is None.
- quantizer_n_q (int) – Number of quantizers (codebooks) in the residual vector quantizer. Default is 8.
- quantizer_bins (int) – Number of entries per codebook (codebook size). Default is 1024.
- quantizer_decay (float) – Decay of the exponential moving average (EMA) used to update the codebooks. Default is 0.99.
- quantizer_kmeans_init (bool) – Flag indicating whether to initialize quantizer with K-means. Default is True.
- quantizer_kmeans_iters (int) – Number of iterations for K-means initialization. Default is 50.
- quantizer_threshold_ema_dead_code (int) – Codewords whose EMA usage count falls below this threshold are treated as dead and replaced. Default is 2.
- quantizer_target_bandwidth (List[float]) – Target bandwidths (kbps) supported by the quantizer. Default is [7.5, 15].
- quantizer_dropout (bool) – Whether to apply quantizer dropout, i.e., randomly use fewer quantizers during training. Default is True.
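- codec_domain (List) – Codec domains used for encoding and decoding (e.g., “time”, “stft”). Default is (“time”, “time”).
- domain_conf (Optional[Dict]) – Configuration parameters for the domain transformations. Default is {}.
- audio_normalize (bool) – Flag indicating whether to normalize audio inputs. Default is False.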
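quantizer_bins and quantizer_target_bandwidth together determine how many codebooks can be active at a given bitrate. The arithmetic below follows the convention of EnCodec-style residual quantizers, where each codebook contributes log2(bins) bits per code frame; whether FunCodec applies exactly this rule is an assumption. The 75 Hz frame rate is a hypothetical value for illustration, not one derived from the defaults above, and `num_quantizers_for_bandwidth` is a hypothetical helper.

```python
import math


def num_quantizers_for_bandwidth(
    bandwidth_kbps: float, frame_rate: float, bins: int, max_n_q: int
) -> int:
    """Number of codebooks that fit in the target bandwidth (EnCodec-style rule, assumed)."""
    bits_per_codebook_per_second = math.log2(bins) * frame_rate
    n_q = int(math.floor(bandwidth_kbps * 1000 / bits_per_codebook_per_second))
    # Use at least one codebook and never more than the quantizer actually has.
    return max(1, min(max_n_q, n_q))


frame_rate = 75.0  # hypothetical code frame rate in Hz
for bw in [1.5, 3.0, 7.5, 15]:
    print(bw, num_quantizers_for_bandwidth(bw, frame_rate, bins=1024, max_n_q=8))
```

With 1024-entry codebooks each stream costs 10 bits per frame, so at 75 Hz one codebook corresponds to 750 bps; bandwidths above 8 × 750 bps are capped by quantizer_n_q.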
Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecGenerator
>>> generator = FunCodecGenerator(sample_rate=24000)
>>> input_audio = torch.randn(1, 1, 24000)  # (B, 1, T): batch size 1, 1 channel
>>> resynthesized_audio, commit_loss, quantization_loss, resynthesized_real = generator(input_audio)
NOTE: Ensure that the input audio has shape (B, 1, T), where B is the batch size and T is the number of audio samples.
Initialize FunCodec Generator.
decode(codes: Tensor)
Run decoding.
This method reconstructs an audio waveform from discrete codes, such as those returned by encode().
- Parameters: codes (Tensor) – Input codes (T_code, N_stream), where T_code is the length of the code sequence and N_stream is the number of streams.
- Returns: Generated waveform (T_wav,), reconstructed from the input codes.
- Return type: Tensor
Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecGenerator
>>> generator = FunCodecGenerator()
>>> codes = generator.encode(torch.randn(1, 1, 24000))  # discrete codes, not random floats
>>> generated_waveform = generator.decode(codes)
>>> print(generated_waveform.shape)
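Conceptually, decoding with a residual quantizer sums, for each code frame, the codewords selected in every stream; the resulting embeddings are then passed to the decoder network. The sketch below shows only that lookup-and-sum step for codes laid out as (T_code, N_stream). The toy codebooks and the name `rvq_decode` are hypothetical, and the actual quantizer and decoder layers of FunCodecGenerator are not reproduced.

```python
from typing import List

Codebook = List[List[float]]


def rvq_decode(codes: List[List[int]], codebooks: List[Codebook]) -> List[List[float]]:
    """Map codes of shape (T_code, N_stream) to one embedding per frame by summing codewords."""
    dim = len(codebooks[0][0])
    frames = []
    for frame_codes in codes:  # one row per code frame
        acc = [0.0] * dim
        for stream, idx in enumerate(frame_codes):  # one column per stream (codebook)
            acc = [a + c for a, c in zip(acc, codebooks[stream][idx])]
        frames.append(acc)
    return frames


codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [-1.0, 1.0]],
    [[0.0, 0.0], [0.25, -0.25], [-0.25, 0.25]],
]
codes = [[1, 1], [2, 0]]  # T_code = 2 frames, N_stream = 2 streams
print(rvq_decode(codes, codebooks))
```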
encode(x: Tensor, target_bw: float | None = None)
Run encoding.
- Parameters:
- x (Tensor) – Input audio waveform (B, 1, T_wav), where B is the batch size and T_wav is the number of audio samples.
- target_bw (float, optional) – Target bandwidth for quantization (cf. quantizer_target_bandwidth). Default is None.
- Returns: Generated codes (T_code, N_stream).
- Return type: Tensor
Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecGenerator
>>> generator = FunCodecGenerator()
>>> input_audio = torch.randn(1, 1, 24000)  # (B, 1, T_wav)
>>> codes = generator.encode(input_audio)
>>> print(codes.shape)  # shape of the generated codes
- Raises:ValueError – If the input tensor x is not of the expected shape.
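To estimate the size of the code tensor returned by encode(), divide the number of input samples by the encoder's total hop size. The sketch below assumes a hypothetical hop of 320 samples and that a trailing partial frame is padded (hence the ceiling); both are assumptions about FunCodecGenerator, and `expected_code_shape` and `code_bits` are hypothetical helpers.

```python
import math
from typing import Tuple


def expected_code_shape(num_samples: int, hop_length: int, n_streams: int) -> Tuple[int, int]:
    """(T_code, N_stream), assuming the last partial frame is padded to a full frame."""
    return math.ceil(num_samples / hop_length), n_streams


def code_bits(num_samples: int, hop_length: int, n_streams: int, bins: int) -> int:
    """Total bits for the codes: one log2(bins)-bit index per frame and stream.

    bins is assumed to be a power of two (e.g. the default 1024).
    """
    t_code, n_stream = expected_code_shape(num_samples, hop_length, n_streams)
    return t_code * n_stream * int(math.log2(bins))


# One second of 24 kHz audio, hypothetical hop of 320 samples, 8 streams.
print(expected_code_shape(24000, 320, 8))
print(code_bits(24000, 320, 8, 1024))
```

Under these assumptions one second of audio yields 75 × 8 codes, i.e. 6000 bits, which corresponds to 6 kbps.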
forward(x: Tensor, use_dual_decoder: bool = False)
Perform generator forward.
This method runs the full generator pass: the input waveform is encoded, quantized, and decoded back into a waveform.
- Parameters:
- x (Tensor) – Input waveform tensor of shape (B, 1, T), where B is the batch size and T is the number of audio samples.
- use_dual_decoder (bool) – Whether to use the dual decoder path. Default is False.
- Returns:
- resynthesized_audio (Tensor): Audio reconstructed by the decoder.
- commit_loss (Tensor): Commitment loss of the quantizer.
- quantization_loss (Tensor): Quantization loss.
- resynthesized_real (Tensor): Additional resynthesized output (see use_dual_decoder).
- Return type: Tuple[Tensor, Tensor, Tensor, Tensor]
Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecGenerator
>>> generator = FunCodecGenerator()
>>> audio_input = torch.randn(2, 1, 24000)  # (B, 1, T)
>>> resynthesized_audio, commit_loss, quantization_loss, resynthesized_real = generator(audio_input)
>>> print(resynthesized_audio.shape)
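forward() returns a commitment loss alongside the reconstruction. In the usual VQ-VAE formulation this is the mean squared distance between the encoder output and its quantized value, with gradients blocked through the quantized branch. The plain-Python sketch below computes only the forward value of that standard formulation; whether FunCodec applies extra weighting or per-stage terms is not covered, and `commitment_loss` is a hypothetical helper.

```python
from typing import List


def commitment_loss(encoder_out: List[float], quantized: List[float]) -> float:
    """Mean squared error between encoder outputs and their quantized values.

    In an autodiff framework the quantized values would be detached, so this
    term only pulls the encoder output towards the selected codewords.
    """
    if len(encoder_out) != len(quantized):
        raise ValueError("encoder_out and quantized must have the same length")
    return sum((z - q) ** 2 for z, q in zip(encoder_out, quantized)) / len(encoder_out)


print(commitment_loss([1.2, 0.8], [1.25, 0.75]))
```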
freq_to_time_transfer(x: Tensor, scale: Tensor | None = None)
Convert frequency domain representation back to time domain.
This method converts the input tensor x from the frequency domain back to the time domain, according to the codec domain configured when the FunCodecGenerator was initialized. An optional scale can be applied to the output, for example to restore the level of audio that was normalized when audio_normalize is enabled.
- Parameters:
- x (torch.Tensor) – Input tensor in the frequency domain. The shape depends on the codec domain:
- For “stft”: (B, C, T) where C is 2 (real and imaginary).
- For “mag_phase”: (B, C, T) where C is 3 (magnitude and phase).
- For “mag_angle”: (B, C, T) where C is 2 (magnitude and angle).
- For “mag_oracle_phase”: (B, C, T) where C is 2 (magnitude and angle).
- scale (torch.Tensor, optional) – A tensor used to scale the output. If provided, its shape must be broadcastable to the output shape.
- Returns: The converted time-domain tensor of shape (B, 1, T_wav), where T_wav is the length of the output waveform.
- Return type: torch.Tensor
NOTE
- The processing behavior may vary based on the codec domain settings.
- Ensure that the input tensor x has the correct shape for the specified codec domain to avoid runtime errors.
Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecGenerator
>>> generator = FunCodecGenerator()
>>> freq_tensor = torch.randn(1, 2, 512)  # example input; the expected layout depends on codec_domain
>>> time_tensor = generator.freq_to_time_transfer(freq_tensor)
>>> print(time_tensor.shape)  # T_wav depends on codec_domain and the inverse transform
>>> scale_tensor = torch.tensor([0.5])
>>> time_tensor_scaled = generator.freq_to_time_transfer(freq_tensor, scale_tensor)
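For the “mag_angle” layout, recovering a complex spectrum amounts to a polar-to-rectangular conversion per time-frequency bin, after which an inverse STFT yields the waveform. The sketch below shows only that per-bin conversion using Python's cmath module; the inverse STFT and any scaling performed by freq_to_time_transfer are omitted, and `mag_angle_to_complex` is a hypothetical helper.

```python
import cmath
import math
from typing import List


def mag_angle_to_complex(mags: List[float], angles: List[float]) -> List[complex]:
    """Combine magnitude and angle channels into complex STFT bins (real + imaginary)."""
    return [cmath.rect(m, a) for m, a in zip(mags, angles)]


bins = mag_angle_to_complex([1.0, 2.0], [0.0, math.pi / 2])
print([(round(b.real, 6), round(b.imag, 6)) for b in bins])
```

The real and imaginary parts computed here correspond to the two channels of the “stft” layout listed above.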
time_to_freq_transfer(x: Tensor)
Convert time-domain audio signals to frequency-domain representations.
This method transforms the input audio tensor x from the time domain to the frequency domain according to the codec domain configuration. Supported representations include STFT and magnitude/phase variants. When audio_normalize is enabled, the input is also normalized and the corresponding scale is returned.
- Parameters:x (torch.Tensor) – Input audio tensor of shape (B, C, T), where B is the batch size, C is the number of channels, and T is the number of time steps.
- Returns:
- x (torch.Tensor): The transformed frequency-domain tensor.
- scale (Optional[torch.Tensor]): Scale tensor used for normalization, if applicable.
- Return type: Tuple[torch.Tensor, Optional[torch.Tensor]]
NOTE: The transformation applied to x depends on the codec domain specified during initialization. When audio normalization is enabled, the input is normalized to a consistent volume level across audio samples, and the returned scale records that normalization.
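Examples
>>> import torch
>>> from espnet2.gan_codec.funcodec.funcodec import FunCodecGenerator
>>> generator = FunCodecGenerator()
>>> audio_tensor = torch.randn(2, 1, 16000)  # batch of 2 audio samples, (B, C, T)
>>> transformed_x, scale = generator.time_to_freq_transfer(audio_tensor)
>>> print(transformed_x.shape)  # output shape depends on codec domain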
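When audio_normalize is enabled, time_to_freq_transfer also returns a scale. A common approach, used for example by EnCodec, is to divide the waveform by its RMS volume plus a small epsilon and later multiply the decoded output by the same scale. The sketch below assumes that approach for a single mono signal; it is not taken from the FunCodec source, and `normalize` and `denormalize` are hypothetical helpers.

```python
import math
from typing import List, Tuple

EPS = 1e-8  # avoids division by zero for silent input


def normalize(wave: List[float]) -> Tuple[List[float], float]:
    """Divide a mono waveform by its RMS volume and return the scale used."""
    volume = math.sqrt(sum(s * s for s in wave) / len(wave))
    scale = volume + EPS
    return [s / scale for s in wave], scale


def denormalize(wave: List[float], scale: float) -> List[float]:
    """Undo normalize() by multiplying with the stored scale."""
    return [s * scale for s in wave]


wave = [0.5, -0.5, 0.5, -0.5]
normed, scale = normalize(wave)
restored = denormalize(normed, scale)
print(scale, normed, restored)
```

Returning the scale separately keeps the codec's input at a consistent level while still allowing the original loudness to be restored after decoding.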