espnet2.gan_codec.soundstream.soundstream.SoundStream
class espnet2.gan_codec.soundstream.soundstream.SoundStream(sampling_rate: int = 24000, generator_params: Dict[str, Any] = {'decoder_final_activation': None, 'decoder_final_activation_params': None, 'decoder_trim_right_ratio': 1.0, 'encdec_activation': 'ELU', 'encdec_activation_params': {'alpha': 1.0}, 'encdec_causal': False, 'encdec_channels': 1, 'encdec_compress': 2, 'encdec_dilation_base': 2, 'encdec_kernel_size': 7, 'encdec_last_kernel_size': 7, 'encdec_lstm': 2, 'encdec_n_filters': 32, 'encdec_n_residual_layers': 1, 'encdec_norm': 'weight_norm', 'encdec_norm_params': {}, 'encdec_pad_mode': 'reflect', 'encdec_ratios': [8, 5, 4, 2], 'encdec_residual_kernel_size': 7, 'encdec_true_skip': False, 'hidden_dim': 128, 'quantizer_bins': 1024, 'quantizer_decay': 0.99, 'quantizer_kmeans_init': True, 'quantizer_kmeans_iters': 50, 'quantizer_n_q': 8, 'quantizer_target_bandwidth': [7.5, 15], 'quantizer_threshold_ema_dead_code': 2}, discriminator_params: Dict[str, Any] = {'complexstft_discriminator_params': {'chan_mults': (1, 2, 4, 4, 8, 8), 'channels': 32, 'hop_length': 256, 'in_channels': 1, 'logits_abs': True, 'n_fft': 1024, 'stft_normalized': False, 'strides': ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), 'win_length': 1024}, 'scale_discriminator_params': {'bias': True, 'channels': 128, 'downsample_scales': [2, 2, 4, 4, 1], 'in_channels': 1, 'kernel_sizes': [15, 41, 5, 3], 'max_downsample_channels': 1024, 'max_groups': 16, 'nonlinear_activation': 'LeakyReLU', 'nonlinear_activation_params': {'negative_slope': 0.1}, 'out_channels': 1}, 'scale_downsample_pooling': 'AvgPool1d', 'scale_downsample_pooling_params': {'kernel_size': 4, 'padding': 2, 'stride': 2}, 'scale_follow_official_norm': False, 'scales': 3}, generator_adv_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, discriminator_adv_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'loss_type': 'mse'}, use_feat_match_loss: bool = True, feat_match_loss_params: Dict[str, Any] = {'average_by_discriminators': False, 'average_by_layers': False, 'include_final_outputs': True}, use_mel_loss: bool = True, mel_loss_params: Dict[str, Any] = {'fmax': None, 'fmin': 0, 'fs': 24000, 'log_base': None, 'n_mels': 80, 'range_end': 11, 'range_start': 6, 'window': 'hann'}, use_dual_decoder: bool = True, lambda_quantization: float = 1.0, lambda_reconstruct: float = 1.0, lambda_commit: float = 1.0, lambda_adv: float = 1.0, lambda_feat_match: float = 2.0, lambda_mel: float = 45.0, cache_generator_outputs: bool = False, use_loss_balancer: bool = False, balance_ema_decay: float = 0.99)
Bases: AbsGANCodec
SoundStream model for audio generation and processing.
This class implements the SoundStream model, a GAN-based neural audio codec. It pairs a generator with a discriminator, both operating on raw audio waveforms, and supports encoding waveforms into discrete codes, decoding codes back into waveforms, and end-to-end resynthesis.
generator
The generator component of the model.
- Type: SoundStreamGenerator
discriminator
The discriminator component of the model.
- Type: SoundStreamDiscriminator
generator_adv_loss
Adversarial loss module for the generator.
- Type: GeneratorAdversarialLoss
generator_reconstruct_loss
Reconstruction loss for the generator.
- Type: torch.nn.L1Loss
discriminator_adv_loss
Adversarial loss module for the discriminator.
- Type: DiscriminatorAdversarialLoss
use_feat_match_loss
Flag indicating whether to use feature matching loss.
- Type: bool
feat_match_loss
Feature matching loss module.
- Type: FeatureMatchLoss
use_mel_loss
Flag indicating whether to use mel loss.
- Type: bool
mel_loss
Mel spectrogram loss module.
- Type: MultiScaleMelSpectrogramLoss
cache_generator_outputs
Flag indicating whether to cache generator outputs.
- Type: bool
fs
Sampling rate for saving audio files.
- Type: int
num_streams
Number of quantization streams.
- Type: int
frame_shift
Frame shift size.
- Type: int
code_size_per_stream
Size of codes per quantization stream.
- Type: List[int]
loss_balancer
Loss balancer for handling multiple losses.
- Type: Optional[Balancer]
Parameters:
- sampling_rate (int) – Sampling rate for audio processing. Default is 24000.
- generator_params (Dict[str, Any]) – Parameters for the generator.
- discriminator_params (Dict[str, Any]) – Parameters for the discriminator.
- generator_adv_loss_params (Dict[str, Any]) – Parameters for the generator adversarial loss.
- discriminator_adv_loss_params (Dict[str, Any]) – Parameters for the discriminator adversarial loss.
- use_feat_match_loss (bool) – Flag to use feature matching loss. Default is True.
- feat_match_loss_params (Dict[str, Any]) – Parameters for the feature matching loss.
- use_mel_loss (bool) – Flag to use mel loss. Default is True.
- mel_loss_params (Dict[str, Any]) – Parameters for the mel loss.
- use_dual_decoder (bool) – Flag to use dual decoder. Default is True.
- lambda_quantization (float) – Weight for quantization loss. Default is 1.0.
- lambda_reconstruct (float) – Weight for reconstruction loss. Default is 1.0.
- lambda_commit (float) – Weight for commitment loss. Default is 1.0.
- lambda_adv (float) – Weight for adversarial loss. Default is 1.0.
- lambda_feat_match (float) – Weight for feature matching loss. Default is 2.0.
- lambda_mel (float) – Weight for mel loss. Default is 45.0.
- cache_generator_outputs (bool) – Flag to cache generator outputs. Default is False.
- use_loss_balancer (bool) – Flag to use loss balancer. Default is False.
- balance_ema_decay (float) – Exponential moving average decay for balancing losses. Default is 0.99.
#### Examples

>>> # Initialize the SoundStream model
>>> sound_stream = SoundStream(
...     sampling_rate=24000,
...     generator_params={"hidden_dim": 128, ...},   # fill in the remaining params
...     discriminator_params={"scales": 3, ...},     # fill in the remaining params
... )
>>> # Perform a forward pass with an audio batch
>>> output = sound_stream.forward(audio_input)
#### NOTE

The model is designed to be used in a training loop where the generator and the discriminator are optimized alternately, as sketched below.
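A minimal sketch of such a loop, assuming a hypothetical `loader` yielding (B, T_wav) audio batches and plain Adam optimizers (ESPnet's trainer performs this dispatch internally; the even/odd alternation here is illustrative only):

>>> import torch
>>> optimizers = [
...     torch.optim.Adam(sound_stream.generator.parameters(), lr=2e-4),
...     torch.optim.Adam(sound_stream.discriminator.parameters(), lr=2e-4),
... ]
>>> for step, audio in enumerate(loader):
...     out = sound_stream(audio, forward_generator=(step % 2 == 0))
...     opt = optimizers[out["optim_idx"]]  # 0: generator, 1: discriminator
...     opt.zero_grad()
...     out["loss"].backward()
...     opt.step()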
Initialize SoundStream model.
- Parameters: TODO (jiatong)
decode(x: Tensor, **kwargs) → Tensor
Run decoding.
This method takes encoded input codes and generates a waveform tensor by passing the codes through the generator’s decoder.
- Parameters:x (Tensor) – Input codes (T_code, N_stream).
- Returns: Generated waveform (T_wav,).
- Return type: Tensor
#### Examples
>>> model = SoundStream()
>>> codes = torch.randint(0, 1024, (100, 8))  # example codes (T_code, N_stream)
>>> waveform = model.decode(codes)
>>> print(waveform.shape)
torch.Size([T_wav])  # shape of the generated waveform tensor
#### NOTE

Ensure that the input tensor `x` has the correct shape and data type expected by the generator's decoder.
encode(x: Tensor, **kwargs) → Tensor
Run encoding.
This method processes the input audio tensor through the generator’s encoder to produce a set of neural codes.
- Parameters:x (Tensor) – Input audio tensor of shape (T_wav,).
- Returns: Generated codes of shape (T_code, N_stream), where T_code is the length of the generated codes and N_stream is the number of quantization streams.
- Return type: Tensor
#### Examples
>>> model = SoundStream()
>>> input_audio = torch.randn(24000)  # 1 second of audio, shape (T_wav,)
>>> codes = model.encode(input_audio)
>>> print(codes.shape)  # output shape is (T_code, N_stream)
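The resulting codes can be passed straight to decode to reconstruct a waveform:

>>> waveform = model.decode(codes)  # reconstructed waveform of shape (T_wav,)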
forward(audio: Tensor, forward_generator: bool = True, **kwargs) → Dict[str, Any]
Perform generator forward.
This method computes the forward pass for either the generator or the discriminator, depending on the forward_generator flag.
- Parameters:
- audio (Tensor) – Audio waveform tensor of shape (B, T_wav).
- forward_generator (bool) – Flag indicating whether to forward the generator (True) or the discriminator (False).
- Returns: A dictionary containing the following keys:
  - loss (Tensor): Scalar tensor representing the total loss.
  - stats (Dict[str, float]): Statistics to be monitored, including the individual loss components.
  - weight (Tensor): Weight tensor summarizing the losses.
  - optim_idx (int): Index of the optimizer to use (0 for the generator, 1 for the discriminator).
- Return type: Dict[str, Any]
#### Examples
>>> audio_input = torch.randn(8, 24000) # Batch of 8 audio samples
>>> model = SoundStream()
>>> output = model.forward(audio_input, forward_generator=True)
>>> print(output.keys())
dict_keys(['loss', 'stats', 'weight', 'optim_idx'])
#### NOTE

The audio input should be pre-processed as necessary to fit the expected input shape (B, T_wav).
inference(x: Tensor, **kwargs) → Dict[str, Tensor]
Run inference on input audio.
This method takes an input audio tensor, encodes it to a neural codec, and then decodes it back to a waveform. It is designed to be used for generating audio samples after the model has been trained.
- Parameters:x (Tensor) – Input audio tensor of shape (T_wav,).
- Returns: A dictionary with the following keys:
  - wav (Tensor): Generated waveform tensor of shape (T_wav,).
  - codec (Tensor): Generated neural codec of shape (T_code, N_stream).
- Return type: Dict[str, Tensor]
#### Examples
>>> model = SoundStream()
>>> input_audio = torch.randn(24000) # Example input (1 second of audio)
>>> output = model.inference(input_audio)
>>> generated_wav = output['wav']
>>> generated_codec = output['codec']
#### NOTE

The input audio tensor should be of shape (T_wav,), where T_wav is the length of the audio signal. The output includes both the generated waveform and the codec representation.
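To listen to the result, the waveform can be written out at the model's sampling rate; a sketch assuming the third-party soundfile package is available:

>>> import soundfile as sf
>>> sf.write("resynth.wav", output["wav"].detach().cpu().numpy(), model.fs)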
meta_info() → Dict[str, Any]
Retrieve meta information about the SoundStream model.
This method provides key details about the model’s configuration, including the sampling rate, number of streams, frame shift, and code size per stream.
- Returns: A dictionary containing the following key-value pairs:
  - fs (int): The sampling rate of the model.
  - num_streams (int): The number of quantization streams.
  - frame_shift (int): The frame shift (hop size) in samples.
  - code_size_per_stream (List[int]): The codebook size for each stream.
- Return type: Dict[str, Any]
#### Examples
>>> sound_stream = SoundStream()
>>> meta_info = sound_stream.meta_info()
>>> print(meta_info)
{'fs': 24000, 'num_streams': 8, 'frame_shift': 320,
 'code_size_per_stream': [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024]}
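With the default encdec_ratios of [8, 5, 4, 2], the frame shift is their product (320 samples), so code frames arrive at 24000 / 320 = 75 Hz and each of the 8 streams contributes log2(1024) = 10 bits per frame. A quick bitrate check reusing the dict from above (illustrative arithmetic, not an API call):

>>> import math
>>> frames_per_sec = meta_info["fs"] / meta_info["frame_shift"]
>>> bits_per_frame = sum(math.log2(c) for c in meta_info["code_size_per_stream"])
>>> print(f"{frames_per_sec * bits_per_frame / 1000:.1f} kbps")
6.0 kbps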