espnet2.gan_codec.hificodec.hificodec.HiFiCodec
class espnet2.gan_codec.hificodec.hificodec.HiFiCodec(sampling_rate: int = 16000, generator_params: Dict[str, Any] = {'hidden_dim': 256, 'quantizer_bins': 1024, 'quantizer_decay': 0.99, 'quantizer_kmeans_init': True, 'quantizer_kmeans_iters': 50, 'quantizer_n_q': 8, 'quantizer_target_bandwidth': [7.5, 15], 'quantizer_threshold_ema_dead_code': 2, 'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 'resblock_kernel_sizes': [3, 7, 11], 'resblock_num': '1', 'upsample_initial_channel': 512, 'upsample_kernel_sizes': [16, 11, 8, 4], 'upsample_rates': [8, 5, 4, 2]}, discriminator_params: Dict[str, Any] = {'msstft_discriminator_params': {'activation': 'LeakyReLU', 'activation_params': {'negative_slope': 0.2}, 'filters': 32, 'hop_lengths': [256, 512, 128, 64, 32], 'in_channels': 1, 'n_ffts': [1024, 2048, 512, 256, 128], 'norm': 'weight_norm', 'out_channels': 1, 'win_lengths': [1024, 2048, 512, 256, 128]}, 'periods': [2, 3, 5, 7, 11], 'periods_discriminator_params': {'bias': False, 'channels': 32, 'downsample_scales': [3, 3, 3, 3, 1], 'in_channels': 1, 'kernel_sizes': [5, 3], 'max_downsample_channels': 1024, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, 'scale_discriminator_params': {'bias': False, 'channels': 128, 'downsample_scales': [2, 2, 4, 4, 1], 'in_channels': 1, 'kernel_sizes': [15, 41, 5, 3], 'max_downsample_channels': 1024, 'max_groups': 16, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1, 'use_spectral_norm': False, 'use_weight_norm': True}, 'scale_downsample_pooling': 'AvgPool1d', 'scale_downsample_pooling_params': {'kernel_size': 4, 'padding': 2, 'stride': 2}, 'scale_follow_official_norm': False, 'scales': 3}, generator_adv_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, discriminator_adv_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, use_feat_match_loss: bool = True, feat_match_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'average_by_layers': False, 'include_final_outputs': True}, use_mel_loss: bool = True, mel_loss_params: Dict[str, Any] = {'fmax': None, 'fmin': 0, 'fs': 16000, 'log_base': None, 'n_mels': 80, 'range_end': 11, 'range_start': 6, 'window': 'hann'}, use_dual_decoder: bool = True, lambda_quantization: float = 1.0, lambda_reconstruct: float = 1.0, lambda_commit: float = 1.0, lambda_adv: float = 1.0, lambda_feat_match: float = 2.0, lambda_mel: float = 45.0, cache_generator_outputs: bool = False, use_loss_balancer: bool = False, balance_ema_decay: float = 0.99)
Bases: AbsGANCodec
HiFiCodec model for high-fidelity audio generation and encoding.
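This model implements a GAN-trained neural audio codec: the generator encodes a waveform into discrete codes and decodes them back into audio, while the discriminator is used during training to judge the realism of the reconstructions. Optional feature matching and mel spectrogram losses, a dual decoder, and a loss balancer can be enabled to improve reconstruction fidelity.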
generator
The generator module for audio synthesis.
- Type: HiFiCodecGenerator
discriminator
The discriminator module for evaluating the quality of generated audio.
generator_adv_loss
Adversarial loss for the generator.
generator_reconstruct_loss
Reconstruction loss for the generator.
- Type: L1Loss
discriminator_adv_loss
Adversarial loss for the discriminator.
use_feat_match_loss
Whether to use the feature matching loss.
- Type: bool
feat_match_loss
Feature matching loss, present only when use_feat_match_loss is True.
- Type: FeatureMatchLoss
use_mel_loss
Whether to use the mel spectrogram loss.
- Type: bool
mel_loss
Mel spectrogram loss, present only when use_mel_loss is True.
use_dual_decoder
Whether to use a dual decoder.
- Type: bool
cache_generator_outputs
Whether to cache the generator outputs so that the following discriminator step can reuse them instead of running the generator again.
- Type: bool
loss_balancer
Loss balancer for adjusting the weights of different loss components, present only when use_loss_balancer is True.
- Type: Balancer
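The role of balance_ema_decay can be illustrated with a small, purely hypothetical sketch. EmaLossBalancer is not ESPnet's Balancer: a balancer in the style of EnCodec's typically rescales gradients rather than raw loss values. This sketch tracks only loss magnitudes, to show how an exponential moving average (EMA) with a decay factor smooths each loss before it is used to rescale that loss toward its target weight.

```python
from typing import Dict


class EmaLossBalancer:
    """Hypothetical sketch: scale each loss by weight / EMA(|loss|).

    Not ESPnet's Balancer; it only illustrates EMA smoothing with a decay
    factor analogous to balance_ema_decay.
    """

    def __init__(self, weights: Dict[str, float], decay: float = 0.99) -> None:
        self.weights = weights  # per-loss target weights (e.g. lambda_mel)
        self.decay = decay  # plays the role of balance_ema_decay
        self.ema: Dict[str, float] = {}

    def update(self, losses: Dict[str, float]) -> Dict[str, float]:
        """Update the EMAs with new loss values and return per-loss scales."""
        scales: Dict[str, float] = {}
        for name, value in losses.items():
            magnitude = abs(value)
            # The first observation of a loss initializes its EMA.
            previous = self.ema.get(name, magnitude)
            self.ema[name] = self.decay * previous + (1.0 - self.decay) * magnitude
            # Guard against division by zero for losses that are exactly 0.
            scales[name] = self.weights[name] / max(self.ema[name], 1e-12)
        return scales


balancer = EmaLossBalancer({"mel": 45.0, "adv": 1.0}, decay=0.5)
print(balancer.update({"mel": 2.0, "adv": 4.0}))  # {'mel': 22.5, 'adv': 0.25}
```

A decay close to 1 (such as the default 0.99) makes the scales change slowly, so a single noisy batch cannot swing the effective loss weights.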
Parameters:
- sampling_rate (int) – The sampling rate of the audio. Default is 16000.
- generator_params (Dict[str, Any]) – Parameters for the generator module.
- discriminator_params (Dict[str, Any]) – Parameters for the discriminator module.
- generator_adv_loss_params (Dict[str, Any]) – Parameters for the generator adversarial loss.
- discriminator_adv_loss_params (Dict[str, Any]) – Parameters for the discriminator adversarial loss.
- use_feat_match_loss (bool) – Whether to use the feature matching loss. Default is True.
- feat_match_loss_params (Dict[str, Any]) – Parameters for the feature matching loss.
- use_mel_loss (bool) – Whether to use the mel spectrogram loss. Default is True.
- mel_loss_params (Dict[str, Any]) – Parameters for the mel spectrogram loss.
- use_dual_decoder (bool) – Whether to use the dual decoder. Requires use_mel_loss to be True. Default is True.
- lambda_quantization (float) – Weight for quantization loss. Default is 1.0.
- lambda_reconstruct (float) – Weight for reconstruction loss. Default is 1.0.
- lambda_commit (float) – Weight for commitment loss. Default is 1.0.
- lambda_adv (float) – Weight for adversarial loss. Default is 1.0.
- lambda_feat_match (float) – Weight for feature matching loss. Default is 2.0.
- lambda_mel (float) – Weight for mel loss. Default is 45.0.
- cache_generator_outputs (bool) – Whether to cache generator outputs. Default is False.
- use_loss_balancer (bool) – Whether to use the loss balancer. Default is False.
- balance_ema_decay (float) – EMA decay rate for the loss balancer. Default is 0.99.
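The lambda_* arguments act as loss weights. The sketch below assumes the generator objective is a plain weighted sum of its terms, with the feature matching and mel terms skipped when their flags are off. The term names and weighted_generator_loss are hypothetical, and the actual combination inside HiFiCodec (including the dual-decoder and loss-balancer paths) may differ.

```python
from typing import Dict, Mapping, Set

# Default weights copied from the documented constructor defaults.
DEFAULT_LAMBDAS: Dict[str, float] = {
    "quantization": 1.0,
    "reconstruct": 1.0,
    "commit": 1.0,
    "adv": 1.0,
    "feat_match": 2.0,
    "mel": 45.0,
}


def weighted_generator_loss(
    terms: Mapping[str, float],
    lambdas: Mapping[str, float] = DEFAULT_LAMBDAS,
    use_feat_match_loss: bool = True,
    use_mel_loss: bool = True,
) -> float:
    """Hypothetical sketch: sum each enabled loss term times its weight."""
    disabled: Set[str] = set()
    if not use_feat_match_loss:
        disabled.add("feat_match")
    if not use_mel_loss:
        disabled.add("mel")
    return sum(lambdas[name] * value for name, value in terms.items() if name not in disabled)


unit_terms = {name: 1.0 for name in DEFAULT_LAMBDAS}
print(weighted_generator_loss(unit_terms))  # 51.0
print(weighted_generator_loss(unit_terms, use_mel_loss=False))  # 6.0
```

With all terms equal to 1.0, the mel term alone contributes 45 of the total 51, which shows how strongly the default lambda_mel emphasizes spectral reconstruction.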
############### Examples
>>> import torch
>>> from espnet2.gan_codec.hificodec.hificodec import HiFiCodec
>>> model = HiFiCodec()
>>> audio_input = torch.randn(1, 16000) # Simulated audio input
>>> output = model.forward(audio_input)
>>> print(output['loss']) # Access the loss from output
######### NOTE The generator and discriminator parameters can be adjusted for specific use cases, such as changing the number of layers or kernel sizes.
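Because generator_params is an ordinary dictionary argument, passing a partial dictionary such as {"hidden_dim": 512} replaces the documented default dictionary as a whole instead of updating it (standard Python default-argument behavior). The sketch below overrides selected keys while keeping the rest. DEFAULT_GENERATOR_PARAMS is an abbreviated, hand-copied subset of the defaults, and override_params is a hypothetical helper whose unknown-key check is an added safeguard, not ESPnet behavior.

```python
import copy
from typing import Any, Dict

# Abbreviated, hand-copied subset of the documented generator_params defaults.
DEFAULT_GENERATOR_PARAMS: Dict[str, Any] = {
    "hidden_dim": 256,
    "quantizer_bins": 1024,
    "quantizer_n_q": 8,
    "upsample_rates": [8, 5, 4, 2],
    "upsample_kernel_sizes": [16, 11, 8, 4],
}


def override_params(defaults: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of defaults with selected keys replaced."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        # Catch typos early instead of silently adding unused keys.
        raise KeyError(f"unknown generator parameters: {sorted(unknown)}")
    params = copy.deepcopy(defaults)  # deep copy so nested lists are not shared
    params.update(overrides)
    return params


params = override_params(DEFAULT_GENERATOR_PARAMS, hidden_dim=512)
print(params["hidden_dim"], params["quantizer_n_q"])  # 512 8
```

With ESPnet installed, the complete defaults can be read with inspect.signature(HiFiCodec.__init__).parameters["generator_params"].default; applying the same merge to that full dictionary gives a value that can be passed as HiFiCodec(generator_params=...).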
- Raises: AssertionError – If use_dual_decoder is True but use_mel_loss is False.
Initialize HiFiCodec model.
decode(x: Tensor, **kwargs) → Tensor
Run decoding.
This method converts codes produced by encode() back into a waveform, reconstructing an approximation of the original audio.
- Parameters: x (Tensor) – Input codes (T_code, N_stream), where T_code is the number of code frames and N_stream is the number of quantizer streams.
- Returns: Generated waveform (T_wav,), where T_wav is the length of the reconstructed audio waveform.
- Return type: Tensor
############### Examples
>>> model = HiFiCodec().eval()
>>> codes = model.encode(torch.randn(1, 16000))  # discrete code indices
>>> waveform = model.decode(codes)
>>> print(waveform.shape)  # the length scales with the number of code frames
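The relationship between code length and waveform length can be sketched as follows, assuming the decoder upsamples each code frame by the product of upsample_rates (consistent with meta_info(), whose frame_shift is calculated from these rates). With the defaults [8, 5, 4, 2], one code frame corresponds to 320 samples, or 20 ms at 16 kHz. Real outputs may differ slightly at the edges because of padding; frame_shift_from_rates and expected_num_samples are hypothetical helpers.

```python
import functools
import operator
from typing import Sequence


def frame_shift_from_rates(upsample_rates: Sequence[int]) -> int:
    """Samples per code frame, assumed to be the product of the upsampling rates."""
    return functools.reduce(operator.mul, upsample_rates, 1)


def expected_num_samples(num_code_frames: int, upsample_rates: Sequence[int]) -> int:
    """Approximate waveform length produced by decoding num_code_frames frames."""
    return num_code_frames * frame_shift_from_rates(upsample_rates)


rates = [8, 5, 4, 2]  # documented default upsample_rates
print(frame_shift_from_rates(rates))  # 320
print(expected_num_samples(50, rates))  # 16000, i.e. one second at 16 kHz
```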
encode(x: Tensor, **kwargs) → Tensor
Run encoding.
- Parameters: x (Tensor) – Input audio (T_wav,).
- Returns: Generated codes (T_code, N_stream).
- Return type: Tensor
############### Examples
>>> model = HiFiCodec()
>>> audio_input = torch.randn(1, 16000) # Simulated audio input
>>> codes = model.encode(audio_input)
>>> print(codes.shape) # Expected output shape: (T_code, N_stream)
######### NOTE This method calls the generator's encode method on the input audio and returns the resulting discrete codes, one code index per frame and stream.
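In a residual quantizer, each stream corresponds to one quantization stage. The following generic residual vector quantization (RVQ) sketch shows how a single latent frame becomes one index per stream. HiFiCodec's actual quantizer (configured by quantizer_n_q, quantizer_bins, and the other quantizer_* parameters) may use a grouped variant with learned codebooks, so the tiny hand-written codebooks and function names below are purely illustrative.

```python
from typing import List, Sequence

Vector = List[float]
Codebook = Sequence[Sequence[float]]


def nearest_index(codebook: Codebook, vector: Sequence[float]) -> int:
    """Index of the codeword with the smallest squared Euclidean distance."""
    return min(
        range(len(codebook)),
        key=lambda i: sum((c - v) ** 2 for c, v in zip(codebook[i], vector)),
    )


def rvq_encode(frame: Sequence[float], codebooks: Sequence[Codebook]) -> List[int]:
    """Quantize one latent frame into one index per stage (stream)."""
    residual: Vector = list(frame)
    codes: List[int] = []
    for codebook in codebooks:
        index = nearest_index(codebook, residual)
        codes.append(index)
        # The next stage only has to model what this stage missed.
        residual = [r - c for r, c in zip(residual, codebook[index])]
    return codes


def rvq_decode(codes: Sequence[int], codebooks: Sequence[Codebook]) -> Vector:
    """Reconstruct the frame by summing the selected codeword of every stage."""
    out: Vector = [0.0] * len(codebooks[0][0])
    for index, codebook in zip(codes, codebooks):
        out = [o + c for o, c in zip(out, codebook[index])]
    return out


codebooks = [
    [[0.0, 0.0], [1.0, 1.0]],  # stage 1: coarse
    [[0.0, 0.0], [0.5, 0.5]],  # stage 2: refines the residual
]
codes = rvq_encode([1.4, 1.6], codebooks)
print(codes, rvq_decode(codes, codebooks))  # [1, 1] [1.5, 1.5]
```

Adding stages lets later codebooks refine the error left by earlier ones, which is why more streams generally mean higher fidelity at a higher bitrate.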
forward(audio: Tensor, forward_generator: bool = True, **kwargs) → Dict[str, Any]
Perform generator or discriminator forward.
This method performs the forward pass of the HiFiCodec model during training. Depending on the forward_generator flag, it runs either the generator step or the discriminator step and returns the corresponding loss together with statistics for monitoring training progress.
- Parameters:
- audio (torch.Tensor) – Audio waveform tensor with shape (B, T_wav), where B is the batch size and T_wav is the number of audio samples.
- forward_generator (bool) – Flag to indicate whether to forward through the generator (True) or the discriminator (False).
- Returns:
- loss (Tensor): Loss scalar tensor representing the total computed loss.
- stats (Dict[str, float]): Dictionary containing various statistics to be monitored during training, such as individual loss components.
- weight (Tensor): Weight tensor used to summarize losses.
- optim_idx (int): Index of the optimizer to be used (0 for generator and 1 for discriminator).
- Return type: Dict[str, Any]
############### Examples
>>> model = HiFiCodec()
>>> audio_input = torch.randn(2, 16000) # Batch of 2 audio samples
>>> output = model.forward(audio_input, forward_generator=True)
>>> print(output['loss'], output['stats'])
######### NOTE Ensure that the input audio tensor is properly preprocessed and has the correct shape before calling this method.
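The optim_idx entry tells the caller which optimizer to step. In ESPnet this alternation is typically handled by the GAN trainer; the sketch below only illustrates the protocol with a hypothetical StubCodec that mimics the documented return dictionary, so it runs without ESPnet or PyTorch.

```python
from typing import Any, Dict, List, Sequence


class StubCodec:
    """Hypothetical stand-in that mimics the documented forward() outputs."""

    def forward(self, audio: Sequence[float], forward_generator: bool = True) -> Dict[str, Any]:
        # optim_idx is 0 for the generator step and 1 for the discriminator step.
        optim_idx = 0 if forward_generator else 1
        return {"loss": 0.0, "stats": {}, "weight": 1, "optim_idx": optim_idx}


def train_step(model: StubCodec, audio: Sequence[float], optimizers: Sequence[str]) -> List[str]:
    """Run one generator step and one discriminator step; return the optimizers used."""
    used: List[str] = []
    for forward_generator in (True, False):
        output = model.forward(audio, forward_generator=forward_generator)
        # A real loop would call output["loss"].backward() and step this optimizer.
        used.append(optimizers[output["optim_idx"]])
    return used


print(train_step(StubCodec(), [0.0] * 4, ["generator_opt", "discriminator_opt"]))
# ['generator_opt', 'discriminator_opt']
```

With the real model, the optimizers would be PyTorch optimizers over the generator and discriminator parameters respectively.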
inference(x: Tensor, **kwargs) → Dict[str, Tensor]
Run inference.
This method encodes the input audio into discrete codes with the generator and then decodes those codes back into a waveform.
- Parameters: x (Tensor) – Input audio (T_wav,).
- Returns:
- wav (Tensor): Generated waveform (T_wav,).
- codec (Tensor): Generated codes (T_code, N_stream).
- Return type: Dict[str, Tensor]
meta_info() → Dict[str, Any]
Retrieve model meta-information.
This method provides essential information about the HiFiCodec model’s configuration, including the sampling rate, number of streams, frame shift, and code size per stream.
- Returns: A dictionary containing the following keys:
- fs (int): The sampling rate of the model.
- num_streams (int): The number of quantizer streams used in the model.
- frame_shift (int): The number of waveform samples per code frame, computed from the generator's upsample_rates.
- code_size_per_stream (List[int]): The codebook size of each quantizer stream.
- Return type: Dict[str, Any]
############### Examples
>>> model = HiFiCodec()
>>> meta = model.meta_info()
>>> print(meta)
{'fs': 16000, 'num_streams': 8, 'frame_shift': 320,
'code_size_per_stream': [1024, 1024, 1024, 1024, 1024, 1024,
1024, 1024]}
######### NOTE This method is useful for understanding the configuration of the model and for debugging purposes.
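The meta_info() fields are enough to estimate the codec's bitrate. The sketch below assumes every stream emits one code per frame and that a code drawn from a codebook of size K carries log2(K) bits. bitrate_bps is a hypothetical helper, and the values use fs=16000, eight streams of 1024 codes, and a frame_shift of 320 samples (the product of the default upsample_rates).

```python
import math
from typing import Any, Dict


def bitrate_bps(meta: Dict[str, Any]) -> float:
    """Bits per second implied by a meta_info()-style dictionary."""
    frames_per_second = meta["fs"] / meta["frame_shift"]
    # Each stream contributes log2(codebook size) bits per frame.
    bits_per_frame = sum(math.log2(size) for size in meta["code_size_per_stream"])
    return frames_per_second * bits_per_frame


meta = {
    "fs": 16000,
    "num_streams": 8,
    "frame_shift": 320,
    "code_size_per_stream": [1024] * 8,
}
print(bitrate_bps(meta))  # 4000.0, i.e. 4 kbps
```

Halving the number of streams halves the bitrate under these assumptions, which is the usual trade-off between compression and reconstruction quality.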