espnet2.gan_codec.soundstream.soundstream.SoundStreamGenerator
class espnet2.gan_codec.soundstream.soundstream.SoundStreamGenerator(sample_rate: int = 24000, hidden_dim: int = 128, encdec_channels: int = 1, encdec_n_filters: int = 32, encdec_n_residual_layers: int = 1, encdec_ratios: List[int] = [8, 5, 4, 2], encdec_activation: str = 'ELU', encdec_activation_params: Dict[str, Any] = {'alpha': 1.0}, encdec_norm: str = 'weight_norm', encdec_norm_params: Dict[str, Any] = {}, encdec_kernel_size: int = 7, encdec_residual_kernel_size: int = 7, encdec_last_kernel_size: int = 7, encdec_dilation_base: int = 2, encdec_causal: bool = False, encdec_pad_mode: str = 'reflect', encdec_true_skip: bool = False, encdec_compress: int = 2, encdec_lstm: int = 2, decoder_trim_right_ratio: float = 1.0, decoder_final_activation: str | None = None, decoder_final_activation_params: dict | None = None, quantizer_n_q: int = 8, quantizer_bins: int = 1024, quantizer_decay: float = 0.99, quantizer_kmeans_init: bool = True, quantizer_kmeans_iters: int = 50, quantizer_threshold_ema_dead_code: int = 2, quantizer_target_bandwidth: List[float] = [7.5, 15])
Bases: Module
SoundStream generator module.
This module implements the generator part of the SoundStream model. It encodes an audio waveform into discrete codes and decodes those codes back into a waveform, using a convolutional encoder (SEANetEncoder), a quantizer, and a convolutional decoder (SEANetDecoder).
encoder
The encoder module that processes the input audio.
- Type: SEANetEncoder
quantizer
The quantization module that encodes the features into a discrete representation.
target_bandwidths
List of target bandwidths (in kbps) for the quantizer.
- Type: List[float]
sample_rate
The sample rate of the audio signals.
- Type: int
frame_rate
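The frame rate (frames per second) of the encoded representation, derived from the sample rate and the product of encdec_ratios (75 with the defaults, since 24000 / 320 = 75).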
- Type: int
decoder
The decoder module that reconstructs the audio from the quantized representation.
- Type: SEANetDecoder
l1_quantization_loss
Loss function for L1 quantization loss.
- Type: torch.nn.L1Loss
l2_quantization_loss
Loss function for L2 quantization loss.
- Type: torch.nn.MSELoss
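The two loss attributes above feed the quantization_loss value that forward() returns. The plain-Python sketch below illustrates that arithmetic under two assumptions not stated on this page: the L1 and L2 terms are simply summed, and both compare the encoder output with its quantized version using the default "mean" reduction of torch.nn.L1Loss and torch.nn.MSELoss. The function name quantization_loss is hypothetical, and real inputs are tensors rather than flat lists.

```python
from typing import Sequence


def quantization_loss(encoder_out: Sequence[float], quantized: Sequence[float]) -> float:
    """Hypothetical helper: mean absolute error + mean squared error.

    Assumption: mirrors torch.nn.L1Loss() + torch.nn.MSELoss() with their
    default reduction="mean", applied to the encoder output and its
    quantized counterpart (flattened here to plain lists).
    """
    if not encoder_out or len(encoder_out) != len(quantized):
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(encoder_out)
    l1 = sum(abs(e - q) for e, q in zip(encoder_out, quantized)) / n
    l2 = sum((e - q) ** 2 for e, q in zip(encoder_out, quantized)) / n
    return l1 + l2


# Toy flattened "encoder output" and its quantized version.
loss = quantization_loss([1.0, 2.0, -1.0, 0.5], [0.5, 2.0, 0.0, 0.5])
print(loss)  # 0.6875 (L1 term 0.375 + L2 term 0.3125)
```

The L1 term penalizes every deviation linearly, while the squared term weights large deviations more heavily; summing them is a common way to get both behaviours.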
Parameters:
- sample_rate (int) – The sample rate of the audio (default: 24000).
- hidden_dim (int) – Dimension of the latent representation produced by the encoder and passed to the quantizer (default: 128).
- encdec_channels (int) – Number of audio channels at the encoder input and decoder output (default: 1).
- encdec_n_filters (int) – Base number of convolutional filters in the encoder/decoder (default: 32).
- encdec_n_residual_layers (int) – Number of residual layers per encoder/decoder block (default: 1).
- encdec_ratios (List[int]) – Up/downsampling ratios of the encoder/decoder blocks; their product is the hop length in samples (default: [8, 5, 4, 2], i.e. a hop length of 320).
- encdec_activation (str) – Activation function used in the encoder/decoder (default: “ELU”).
- encdec_activation_params (Dict[str, Any]) – Keyword arguments for the activation function (default: {“alpha”: 1.0}).
- encdec_norm (str) – Normalization applied to the convolutions (default: “weight_norm”).
- encdec_norm_params (Dict[str, Any]) – Parameters for the normalization (default: {}).
- encdec_kernel_size (int) – Kernel size of the initial convolution (default: 7).
- encdec_residual_kernel_size (int) – Kernel size of the residual layers (default: 7).
- encdec_last_kernel_size (int) – Kernel size of the final convolution (default: 7).
- encdec_dilation_base (int) – Factor by which the dilation grows with each residual layer (default: 2).
- encdec_causal (bool) – Whether to use causal convolutions (default: False).
- encdec_pad_mode (str) – Padding mode for the convolutions (default: “reflect”).
- encdec_true_skip (bool) – Whether residual blocks use an identity skip connection instead of a convolutional one (default: False).
- encdec_compress (int) – Channel reduction factor inside the residual branches (default: 2).
- encdec_lstm (int) – Number of LSTM layers at the end of the encoder and the start of the decoder (default: 2).
- decoder_trim_right_ratio (float) – Fraction of the transposed-convolution padding trimmed on the right in the causal setup; 1.0 trims all of it on the right (default: 1.0).
- decoder_final_activation (Optional[str]) – Final activation function of the decoder (default: None).
- decoder_final_activation_params (Optional[dict]) – Parameters for the final activation (default: None).
- quantizer_n_q (int) – Number of quantizers (codebooks) (default: 8).
- quantizer_bins (int) – Codebook size, i.e. number of entries per codebook (default: 1024).
- quantizer_decay (float) – Decay of the exponential moving average used to update the codebooks (default: 0.99).
- quantizer_kmeans_init (bool) – Whether to initialize the codebooks with k-means (default: True).
- quantizer_kmeans_iters (int) – Number of k-means iterations for codebook initialization (default: 50).
- quantizer_threshold_ema_dead_code (int) – Codewords whose EMA cluster size falls below this threshold are treated as dead and replaced (default: 2).
- quantizer_target_bandwidth (List[float]) – Target bandwidths in kbps (default: [7.5, 15]).
Returns: This constructor does not return any value.
Return type: None
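The frame rate and bitrate follow from the constructor arguments. The sketch below works through that arithmetic for the defaults. It assumes frame_rate = ceil(sample_rate / prod(encdec_ratios)), that each quantizer emits one index of log2(quantizer_bins) bits per frame, and an EnCodec-style rule for turning a target bandwidth into a number of quantizers (the budget divided by the per-quantizer cost, floored and clipped to [1, quantizer_n_q]). The helper names are hypothetical; check the ESPnet quantizer source before relying on the exact rule.

```python
import math
from typing import Sequence


def frame_rate(sample_rate: int, ratios: Sequence[int]) -> int:
    # Hop length in samples = product of the up/downsampling ratios.
    hop_length = math.prod(ratios)
    return math.ceil(sample_rate / hop_length)


def kbps_per_quantizer(frame_rate_hz: int, bins: int) -> float:
    # Each quantizer emits one codebook index of log2(bins) bits per frame.
    return frame_rate_hz * math.log2(bins) / 1000


def quantizers_for_bandwidth(target_kbps: float, frame_rate_hz: int, bins: int, n_q: int) -> int:
    # Assumed EnCodec-style rule: as many quantizers as fit in the budget,
    # at least 1 and at most the n_q that exist.
    per_q = kbps_per_quantizer(frame_rate_hz, bins)
    return min(n_q, max(1, math.floor(target_kbps / per_q)))


fr = frame_rate(24000, [8, 5, 4, 2])
print(fr)                            # 75
print(kbps_per_quantizer(fr, 1024))  # 0.75
for bw in (3.0, 7.5, 15.0):
    print(bw, quantizers_for_bandwidth(bw, fr, 1024, n_q=8))
```

With the defaults this gives 75 frames per second and 0.75 kbps per quantizer, so all eight quantizers together amount to 6 kbps; under the assumed clipping, both default target bandwidths (7.5 and 15 kbps) would map to quantizer_n_q = 8.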
########### Examples
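>>> import torch
>>> from espnet2.gan_codec.soundstream.soundstream import SoundStreamGenerator
>>> generator = SoundStreamGenerator(sample_rate=24000)
>>> input_tensor = torch.randn(2, 1, 24000)  # (B, 1, T)
>>> output_audio, commit_loss, quantization_loss, resyn_audio_real = generator(input_tensor)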
####### NOTE The input tensor for the forward method should have a shape of (B, 1, T), where B is the batch size and T is the length of the audio sequence.
decode(codes: Tensor)
Run decoding.
This method takes neural codes as input and generates a waveform.
- Parameters:
- codes (Tensor) – Input codes (T_code, N_stream), i.e. the discrete representation of an audio signal as produced by encode().
- Returns: Generated waveform (T_wav,). This tensor contains the reconstructed audio signal derived from the input codes.
- Return type: Tensor
########### Examples
>>> import torch
>>> from espnet2.gan_codec.soundstream.soundstream import SoundStreamGenerator
>>> model = SoundStreamGenerator()
>>> codes = model.encode(torch.randn(1, 1, 24000))  # integer code indices
>>> waveform = model.decode(codes)
>>> print(waveform.shape)
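Assuming the quantizer is a residual vector quantizer (suggested by quantizer_n_q, quantizer_bins, and the k-means/EMA options, but not stated on this page), decoding first turns each frame's codes back into a latent vector by summing one codeword per stage; SEANetDecoder then maps those vectors to audio. The plain-Python sketch below shows only that lookup-and-sum step for a single frame. rvq_decode and the toy codebooks are hypothetical.

```python
from typing import List, Sequence

Vector = List[float]


def rvq_decode(codes: Sequence[int], codebooks: Sequence[Sequence[Vector]]) -> Vector:
    """Hypothetical helper: decode one frame by summing one codeword per stage.

    codes[i] indexes codebooks[i]; only len(codes) stages are used.
    """
    dim = len(codebooks[0][0])
    out = [0.0] * dim
    for stage, index in enumerate(codes):
        codeword = codebooks[stage][index]
        out = [o + c for o, c in zip(out, codeword)]
    return out


# Two toy stages with 2-D codewords (made-up values).
codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]],     # stage 1: coarse
    [[0.0, 0.0], [0.25, -0.25], [-0.5, 0.5]],  # stage 2: refines the residual
]
print(rvq_decode([1, 1], codebooks))  # [1.25, 0.75]
print(rvq_decode([1], codebooks))     # [1.0, 1.0]
```

Passing fewer codes than stages truncates the sum, which is how residual schemes trade reconstruction detail for a lower bitrate.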
encode(x: Tensor, target_bw: float | None = None)
Run encoding.
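- Parameters:
- x (Tensor) – Input audio (B, 1, T_wav).
- target_bw (Optional[float]) – Target bandwidth in kbps, which sets how many quantizers are used. If None, the last entry of target_bandwidths is used.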
- Returns: Generated codes (T_code, N_stream).
- Return type: Tensor
########### Examples
>>> import torch
>>> from espnet2.gan_codec.soundstream.soundstream import SoundStreamGenerator
>>> model = SoundStreamGenerator()
>>> input_audio = torch.randn(1, 1, 24000)  # (B, 1, T_wav): one second at 24 kHz
>>> codes = model.encode(input_audio, target_bw=7.5)
>>> print(codes.shape)
####### NOTE The input tensor should have a shape of (B, 1, T_wav), where B is the batch size and T_wav is the number of samples. The output contains the discrete code indices that represent the audio.
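Under the same residual-vector-quantization assumption, encoding quantizes each encoder frame greedily: every stage picks the codeword nearest to what the previous stages left over, and target_bw would control how many stages run. The sketch below shows that greedy search for one frame in plain Python with squared Euclidean distance. rvq_encode, sq_dist, and the toy codebooks are hypothetical, and learned codebook updates (EMA decay, k-means initialization, dead-code replacement) are left out.

```python
from typing import List, Sequence

Vector = List[float]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    # Squared Euclidean distance between two vectors.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def rvq_encode(x: Sequence[float], codebooks: Sequence[Sequence[Vector]], n_q: int) -> List[int]:
    """Hypothetical helper: greedy residual quantization of one frame using the first n_q stages."""
    residual = list(x)
    codes: List[int] = []
    for codebook in codebooks[:n_q]:
        # Pick the codeword nearest to the current residual.
        index = min(range(len(codebook)), key=lambda i: sq_dist(residual, codebook[i]))
        codes.append(index)
        # Subtract it so the next stage quantizes what is left over.
        residual = [r - c for r, c in zip(residual, codebook[index])]
    return codes


# Two toy stages with 2-D codewords (made-up values).
codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]],
    [[0.0, 0.0], [0.25, -0.25], [-0.5, 0.5]],
]
print(rvq_encode([1.2, 0.8], codebooks, n_q=2))  # [1, 1]
print(rvq_encode([1.2, 0.8], codebooks, n_q=1))  # [1]
```

The first stage captures the coarse position ([1.0, 1.0]); the second stage encodes the leftover [0.2, -0.2] with its closest codeword, [0.25, -0.25].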
forward(x: Tensor, use_dual_decoder: bool = False)
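Run the generator forward pass (encoder, quantizer, and decoder).
- Parameters:
- x (Tensor) – Audio waveform tensor (B, 1, T_wav).
- use_dual_decoder (bool) – Whether to also run the decoder on the unquantized encoder output (default: False).
- Returns: A tuple of four elements:
- output_audio (Tensor): Waveform reconstructed from the quantized representation.
- commit_loss (Tensor): Commitment loss from the quantizer.
- quantization_loss (Tensor): Quantization loss computed with l1_quantization_loss and l2_quantization_loss.
- resyn_audio_real (Tensor or None): Waveform decoded from the unquantized encoder output when use_dual_decoder is True.
- Return type: Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]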
########### Examples
>>> import torch
>>> from espnet2.gan_codec.soundstream.soundstream import SoundStreamGenerator
>>> model = SoundStreamGenerator()
>>> audio_input = torch.randn(8, 1, 24000)  # (B, 1, T_wav)
>>> output_audio, commit_loss, quantization_loss, _ = model(audio_input)
>>> print(output_audio.shape)
####### NOTE This method runs only the generator (encoder, quantizer, and decoder). The discriminator and the adversarial losses are handled by the enclosing SoundStream model, whose forward method takes the forward_generator flag.