espnet2.gan_codec.funcodec.funcodec.FunCodec
class espnet2.gan_codec.funcodec.funcodec.FunCodec(sampling_rate: int = 24000, generator_params: Dict[str, Any] = {'codec_domain': ['time', 'time'], 'decoder_final_activation': None, 'decoder_final_activation_params': None, 'decoder_trim_right_ratio': 1.0, 'domain_conf': {}, 'encdec_activation': 'ELU', 'encdec_activation_params': {'alpha': 1.0}, 'encdec_causal': False, 'encdec_channels': 1, 'encdec_compress': 2, 'encdec_dilation_base': 2, 'encdec_kernel_size': 7, 'encdec_last_kernel_size': 7, 'encdec_lstm': 2, 'encdec_n_filters': 32, 'encdec_n_residual_layers': 1, 'encdec_norm': 'weight_norm', 'encdec_norm_params': {}, 'encdec_pad_mode': 'reflect', 'encdec_ratios': [(8, 1), (5, 1), (4, 1), (2, 1)], 'encdec_residual_kernel_size': 7, 'encdec_true_skip': False, 'hidden_dim': 128, 'quantizer_bins': 1024, 'quantizer_decay': 0.99, 'quantizer_dropout': True, 'quantizer_kmeans_init': True, 'quantizer_kmeans_iters': 50, 'quantizer_n_q': 8, 'quantizer_target_bandwidth': [7.5, 15], 'quantizer_threshold_ema_dead_code': 2}, discriminator_params: Dict[str, Any] = {'complexstft_discriminator_params': {'chan_mults': (1, 2, 4, 4, 8, 8), 'channels': 32, 'hop_length': 256, 'in_channels': 1, 'logits_abs': True, 'n_fft': 1024, 'stft_normalized': False, 'strides': ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), 'win_length': 1024}, 'period_discriminator_params': {'bias': True, 'channels': 32, 'downsample_scales': [3, 3, 3, 3, 1], 'in_channels': 1, 'kernel_sizes': [5, 3], 'max_downsample_channels': 1024, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, 'periods': [2, 3, 5, 7, 11], 'scale_discriminator_params': {'bias': True, 'channels': 128, 'downsample_scales': [2, 2, 4, 4, 1], 'in_channels': 1, 'kernel_sizes': [15, 41, 5, 3], 'max_downsample_channels': 1024, 'max_groups': 16, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 
'out_channels': 1}, 'scale_downsample_pooling': 'AvgPool1d', 'scale_downsample_pooling_params': {'kernel_size': 4, 'padding': 2, 'stride': 2}, 'scale_follow_official_norm': False, 'scales': 3}, generator_adv_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, discriminator_adv_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, use_feat_match_loss: bool = True, feat_match_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'average_by_layers': False, 'include_final_outputs': True}, use_mel_loss: bool = True, mel_loss_params: Dict[str, Any] = {'fmax': None, 'fmin': 0, 'fs': 24000, 'log_base': None, 'n_mels': 80, 'range_end': 11, 'range_start': 6, 'window': 'hann'}, use_dual_decoder: bool = False, lambda_quantization: float = 1.0, lambda_reconstruct: float = 1.0, lambda_commit: float = 1.0, lambda_adv: float = 1.0, lambda_feat_match: float = 2.0, lambda_mel: float = 45.0, cache_generator_outputs: bool = False)
Bases: AbsGANCodec
FunCodec model for audio processing using GAN architecture.
This model implements a GAN-based codec that encodes audio into discrete codes and decodes those codes back into waveforms. It pairs a generator, which produces the audio waveforms, with a discriminator, which evaluates the quality of the generated audio. Training can combine adversarial, reconstruction, quantization, commitment, feature matching, and mel losses, each scaled by its own lambda_* weight.
generator
The generator module.
- Type:FunCodecGenerator
discriminator
The discriminator module.
generator_adv_loss
Adversarial loss for the generator.
generator_reconstruct_loss
Reconstruction loss for the generator.
- Type: torch.nn.L1Loss
discriminator_adv_loss
Adversarial loss for the discriminator.
use_feat_match_loss
Flag to indicate whether to use feature matching loss.
- Type: bool
feat_match_loss
Feature matching loss module.
- Type:FeatureMatchLoss
use_mel_loss
Flag to indicate whether to use mel loss.
- Type: bool
mel_loss
Mel loss module.
use_dual_decoder
Flag to indicate whether to use dual decoder.
- Type: bool
lambda_quantization
Weight for quantization loss.
- Type: float
lambda_reconstruct
Weight for reconstruction loss.
- Type: float
lambda_commit
Weight for commitment loss.
- Type: float
lambda_adv
Weight for adversarial loss.
- Type: float
lambda_feat_match
Weight for feature matching loss.
- Type: float
lambda_mel
Weight for mel loss.
- Type: float
cache_generator_outputs
Flag to cache generator outputs.
- Type: bool
fs
Sampling rate for saving audio files.
- Type: int
num_streams
Number of streams for quantization.
- Type: int
frame_shift
Frame shift calculated from encoder ratios.
- Type: int
code_size_per_stream
Code size for each stream.
- Type: List[int]
Parameters:
- sampling_rate (int) – Sampling rate for the audio (default: 24000).
- generator_params (Dict[str, Any]) – Parameters for the generator.
- discriminator_params (Dict[str, Any]) – Parameters for the discriminator.
- generator_adv_loss_params (Dict[str, Any]) – Parameters for generator adversarial loss.
- discriminator_adv_loss_params (Dict[str, Any]) – Parameters for discriminator adversarial loss.
- use_feat_match_loss (bool) – Whether to use feature matching loss (default: True).
- feat_match_loss_params (Dict[str, Any]) – Parameters for feature matching loss.
- use_mel_loss (bool) – Whether to use mel loss (default: True).
- mel_loss_params (Dict[str, Any]) – Parameters for mel loss.
- use_dual_decoder (bool) – Whether to use dual decoder (default: False).
- lambda_quantization (float) – Weight for quantization loss (default: 1.0).
- lambda_reconstruct (float) – Weight for reconstruction loss (default: 1.0).
- lambda_commit (float) – Weight for commitment loss (default: 1.0).
- lambda_adv (float) – Weight for adversarial loss (default: 1.0).
- lambda_feat_match (float) – Weight for feature matching loss (default: 2.0).
- lambda_mel (float) – Weight for mel loss (default: 45.0).
- cache_generator_outputs (bool) – Flag to cache generator outputs (default: False).
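The lambda_* weights above scale the individual loss terms. The rule for combining them is not stated here, so the following is only a minimal sketch that assumes a plain weighted sum. The helper `combine_generator_losses` and the term names are hypothetical and are not part of the FunCodec API.

```python
from typing import Dict

# Default weights, copied from the parameter list above.
DEFAULT_WEIGHTS: Dict[str, float] = {
    "quantization": 1.0,
    "reconstruct": 1.0,
    "commit": 1.0,
    "adv": 1.0,
    "feat_match": 2.0,
    "mel": 45.0,
}


def combine_generator_losses(
    terms: Dict[str, float], weights: Dict[str, float] = DEFAULT_WEIGHTS
) -> float:
    """Assumed combination rule: weighted sum over the terms that are present.

    Terms that are switched off (for example mel loss when use_mel_loss is
    False) are simply left out of ``terms``.
    """
    return sum(weights[name] * value for name, value in terms.items())


# Placeholder loss values, chosen only to make the arithmetic easy to follow.
total = combine_generator_losses(
    {"adv": 0.5, "feat_match": 0.25, "mel": 0.125, "reconstruct": 0.25}
)
print(total)  # 0.5 + 2.0 * 0.25 + 45.0 * 0.125 + 0.25 = 6.875
```

With the defaults, the mel term dominates the sum unless its raw value is much smaller than the others, which is worth keeping in mind when changing lambda_mel.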
############### Examples
>>> codec = FunCodec()
>>> audio_input = torch.randn(1, 24000) # Example input tensor
>>> output = codec.forward(audio_input)
>>> print(output['loss']) # Access the computed loss
####### NOTE Ensure that the input audio tensor has the shape (B, T_wav) where B is the batch size and T_wav is the length of the audio waveform.
- Raises:AssertionError – If dual decoder is enabled without mel loss.
Initialize FunCodec model.
- Parameters:TODO (jiatong)
decode(x: Tensor, **kwargs) → Tensor
Run decoding.
This method takes the encoded input codes and generates a waveform tensor as output.
- Parameters:x (Tensor) – Input codes (T_code, N_stream).
- Returns: Generated waveform (T_wav,).
- Return type: Tensor
############### Examples
>>> model = FunCodec()
>>> codes = torch.randint(0, 1024, (100, 8)) # Example integer codes (T_code, N_stream)
>>> waveform = model.decode(codes)
>>> print(waveform.shape) # Output shape depends on the model's configuration and the input codes
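Because decode expects codes of shape (T_code, N_stream), a quick sanity check before decoding can catch malformed input. The sketch below assumes that each code is an integer index into a codebook whose size is given by code_size_per_stream. `validate_codes` is a hypothetical helper, not a FunCodec method, and it works on plain nested lists so that it runs without a model.

```python
from typing import List, Sequence


def validate_codes(
    codes: Sequence[Sequence[int]], code_size_per_stream: List[int]
) -> None:
    """Raise ValueError unless codes is a (T_code, N_stream) grid of valid indices.

    Assumption: stream s only accepts indices in [0, code_size_per_stream[s]).
    """
    n_stream = len(code_size_per_stream)
    for t, frame in enumerate(codes):
        if len(frame) != n_stream:
            raise ValueError(
                f"frame {t} has {len(frame)} streams, expected {n_stream}"
            )
        for s, (index, size) in enumerate(zip(frame, code_size_per_stream)):
            if not isinstance(index, int) or not 0 <= index < size:
                raise ValueError(f"frame {t}, stream {s}: invalid code {index!r}")


sizes = [1024] * 8  # eight streams with 1024 entries each, as in the defaults
validate_codes([[0] * 8, [1023] * 8], sizes)  # passes silently
```

For a tensor of codes, `codes.tolist()` yields the nested-list form this helper takes.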
encode(x: Tensor, **kwargs) → Tensor
Run encoding.
- Parameters:x (Tensor) – Input audio (T_wav,).
- Returns: Generated codes (T_code, N_stream).
- Return type: Tensor
############### Examples
>>> model = FunCodec()
>>> audio_input = torch.randn(1, 24000) # Example audio tensor
>>> codes = model.encode(audio_input)
>>> print(codes.shape) # Output shape will depend on model config
####### NOTE This method utilizes the generator’s encoding capabilities to transform audio waveforms into a compressed representation.
forward(audio: Tensor, forward_generator: bool = True, **kwargs) → Dict[str, Any]
Perform generator forward.
- Parameters:
- audio (Tensor) – Audio waveform tensor (B, T_wav).
- forward_generator (bool) – Whether to forward generator.
- Returns:
- loss (Tensor): Loss scalar tensor.
- stats (Dict[str, float]): Statistics to be monitored.
- weight (Tensor): Weight tensor to summarize losses.
- optim_idx (int): Optimizer index (0 for G and 1 for D).
- Return type: Dict[str, Any]
############### Examples
>>> model = FunCodec()
>>> audio_input = torch.randn(8, 220500) # Example input
>>> output = model.forward(audio_input)
>>> print(output.keys())
dict_keys(['loss', 'stats', 'weight', 'optim_idx'])
####### NOTE This method determines whether to use the generator or the discriminator for the forward pass based on the forward_generator argument.
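The optim_idx value in the returned dictionary tells the caller which optimizer the loss belongs to (0 for the generator, 1 for the discriminator). The sketch below illustrates one way a training loop might alternate the two passes. `fake_forward` is a hypothetical stand-in that returns only the documented keys with placeholder values, and the optimizer names are illustrative.

```python
from typing import Any, Dict, List


def fake_forward(forward_generator: bool = True) -> Dict[str, Any]:
    """Hypothetical stand-in for FunCodec.forward with placeholder values."""
    return {
        "loss": 1.0 if forward_generator else 0.5,  # placeholder scalars
        "stats": {},
        "weight": 1,
        "optim_idx": 0 if forward_generator else 1,  # 0 for G, 1 for D
    }


def training_turns(num_steps: int) -> List[str]:
    """Return the order in which the two optimizers would be stepped."""
    optimizer_names = ["generator_optimizer", "discriminator_optimizer"]
    used: List[str] = []
    for _ in range(num_steps):
        # One generator pass followed by one discriminator pass per step.
        for forward_generator in (True, False):
            output = fake_forward(forward_generator)
            used.append(optimizer_names[output["optim_idx"]])
    return used


print(training_turns(1))  # ['generator_optimizer', 'discriminator_optimizer']
```

In a real loop, the selected optimizer would be zeroed, the loss backpropagated, and the optimizer stepped. The actual alternation schedule is determined by the trainer and may differ from this one.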
inference(x: Tensor, **kwargs) → Dict[str, Tensor]
FunCodec model for audio generation and encoding.
This model implements a generative audio codec based on GAN architecture. It is designed to encode audio signals into a compact representation and decode them back to waveform format.
############### Examples
>>> # Create a FunCodec model instance
>>> codec = FunCodec(sampling_rate=22050)
>>> # Forward pass with audio input
>>> audio_input = torch.randn(1, 22050) # Simulated audio
>>> output = codec(audio_input)
meta_info() → Dict[str, Any]
Retrieve metadata information of the FunCodec model.
This method returns essential information about the current configuration of the FunCodec model, including the sampling rate, number of streams, frame shift, and code size per stream. This information can be useful for understanding the model’s parameters and for debugging purposes.
- Returns: A dictionary containing the following keys:
- ’fs’ (int): The sampling rate of the audio.
- ’num_streams’ (int): The number of streams in the codec.
- ’frame_shift’ (int): The frame shift size used in the model.
- ’code_size_per_stream’ (List[int]): A list indicating the code size for each stream.
- Return type: Dict[str, Any]
############### Examples
>>> model = FunCodec()
>>> info = model.meta_info()
>>> print(info)
{'fs': 24000, 'num_streams': 8, 'frame_shift': 128,
'code_size_per_stream': [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]}
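The fields returned by meta_info are enough to estimate the frame rate and a nominal bitrate. The sketch below uses the dictionary printed in the example above and assumes that every code costs log2(code_size) bits with no entropy coding. `nominal_bitrate` is a hypothetical helper, not part of the FunCodec API.

```python
import math
from typing import Any, Dict


def nominal_bitrate(info: Dict[str, Any]) -> float:
    """Bits per second implied by a meta_info-style dictionary.

    Assumption: each stream emits one code per frame, and each code costs
    log2(code_size) bits.
    """
    frame_rate = info["fs"] / info["frame_shift"]  # code frames per second
    bits_per_frame = sum(math.log2(size) for size in info["code_size_per_stream"])
    return frame_rate * bits_per_frame


# Dictionary copied from the example output above.
info = {
    "fs": 24000,
    "num_streams": 8,
    "frame_shift": 128,
    "code_size_per_stream": [1024] * 8,
}
bitrate = nominal_bitrate(info)
print(bitrate)  # 187.5 frames/s * 80 bits/frame = 15000.0
```

Using fewer streams lowers the estimate proportionally, since each stream with 1024 entries contributes 10 bits per frame.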