espnet2.gan_codec.shared.quantizer.residual_vq.ResidualVectorQuantizer
class espnet2.gan_codec.shared.quantizer.residual_vq.ResidualVectorQuantizer(dimension: int = 256, codebook_dim: int = 512, n_q: int = 8, bins: int = 1024, decay: float = 0.99, kmeans_init: bool = True, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, quantizer_dropout: bool = False)
Bases: Module
Residual Vector Quantizer.
This class implements residual vector quantization: a stack of n_q vector quantizers in which each quantizer encodes the residual left by the previous one, so that successive stages refine the approximation of the input. The codebooks are updated with an exponential moving average and can optionally be initialized via k-means.
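The residual scheme can be illustrated with a toy example in pure Python. This is only a sketch of the general technique, not the ESPnet implementation: the helper names `nearest` and `rvq_encode` and the two small codebooks are hypothetical, and real codebooks are learned tensors.

```python
# Toy residual vector quantization (illustration only, not the ESPnet code).

def nearest(codebook, vec):
    """Return the index of the codebook entry closest to vec (squared L2)."""
    def dist(entry):
        return sum((e - v) ** 2 for e, v in zip(entry, vec))
    return min(range(len(codebook)), key=lambda i: dist(codebook[i]))


def rvq_encode(codebooks, vec):
    """Quantize vec stage by stage; each stage sees the previous residual."""
    residual = list(vec)
    codes = []
    for codebook in codebooks:
        idx = nearest(codebook, residual)
        codes.append(idx)
        # Subtract the chosen entry so the next stage refines what is left.
        residual = [r - c for r, c in zip(residual, codebook[idx])]
    return codes, residual


# Two hypothetical stages: a coarse codebook followed by a finer one.
codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]],
    [[0.0, 0.0], [0.25, -0.25], [-0.25, 0.25]],
]
codes, residual = rvq_encode(codebooks, [1.2, 0.8])
print(codes)     # one index per stage
print(residual)  # error remaining after both stages
```

Each additional stage only has to model what earlier stages missed, which is why using fewer quantizers lowers the bitrate at the cost of a larger remaining residual.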
n_q
Number of residual vector quantizers used.
- Type: int
dimension
Dimension of the codebooks.
- Type: int
codebook_dim
Dimension of the codebook vectors.
- Type: int
bins
Codebook size.
- Type: int
decay
Decay for exponential moving average over the codebooks.
- Type: float
kmeans_init
Whether to use kmeans to initialize the codebooks.
- Type: bool
kmeans_iters
Number of iterations used for kmeans initialization.
- Type: int
threshold_ema_dead_code
Threshold for dead code expiration.
- Type: int
quantizer_dropout
Flag to indicate whether to apply dropout.
- Type: bool
Parameters:
- dimension (int) – Dimension of the codebooks.
- codebook_dim (int) – Dimension of the codebook vectors.
- n_q (int) – Number of residual vector quantizers used.
- bins (int) – Codebook size.
- decay (float) – Decay for exponential moving average over the codebooks.
- kmeans_init (bool) – Whether to use kmeans to initialize the codebooks.
- kmeans_iters (int) – Number of iterations used for kmeans initialization.
- threshold_ema_dead_code (int) – Threshold for dead code expiration. Codes whose exponential moving average cluster size falls below this threshold are replaced with a randomly selected vector from the current batch.
- quantizer_dropout (bool) – Flag to indicate whether to apply dropout.
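The interaction between `decay` and `threshold_ema_dead_code` can be sketched as follows. This is a simplified illustration in pure Python under stated assumptions: the function names `ema_update` and `expire_dead_codes` are hypothetical, and the actual implementation operates on tensors and may differ in detail.

```python
import random


def ema_update(cluster_size, batch_counts, decay=0.99):
    """Exponential moving average of how often each code was selected."""
    return [decay * old + (1.0 - decay) * new
            for old, new in zip(cluster_size, batch_counts)]


def expire_dead_codes(codebook, cluster_size, batch, threshold=2, rng=random):
    """Replace rarely used codes with a randomly chosen vector from the batch."""
    refreshed = []
    for code, size in zip(codebook, cluster_size):
        if size < threshold:
            refreshed.append(list(rng.choice(batch)))  # dead code: resample
        else:
            refreshed.append(code)                     # live code: keep
    return refreshed


codebook = [[0.0, 0.0], [5.0, 5.0]]
cluster_size = [10.0, 2.0]
batch = [[0.1, -0.1], [0.2, 0.0], [-0.1, 0.1]]
batch_counts = [3, 0]  # all three batch vectors were assigned to code 0

# A small decay is used here only to make the effect visible in one step.
cluster_size = ema_update(cluster_size, batch_counts, decay=0.5)
codebook = expire_dead_codes(codebook, cluster_size, batch, threshold=2)
print(cluster_size)  # [6.5, 1.0]
```

After the update, the second code's average usage falls below the threshold, so it is replaced by one of the batch vectors while the first code is kept.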
############### Examples
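>>> import torch
>>> from espnet2.gan_codec.shared.quantizer.residual_vq import ResidualVectorQuantizer
>>> rvq = ResidualVectorQuantizer(dimension=256, n_q=8, bins=1024)
>>> input_tensor = torch.randn(1, 256, 50)  # assumed layout: (batch, dimension, frames)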
>>> sample_rate = 16000
>>> result = rvq(input_tensor, sample_rate)
>>> quantized_output = result.quantized
>>> codes = result.codes
decode(codes: Tensor) → Tensor
Decode the given codes to the quantized representation.
This method takes the quantization codes produced by the encoding process and converts them back into the quantized tensor representation.
- Parameters: codes (torch.Tensor) – A tensor containing the quantization codes to be decoded.
- Returns: The decoded quantized representation corresponding to the provided codes.
- Return type: torch.Tensor
############### Examples
>>> rvq = ResidualVectorQuantizer()
>>> codes = torch.tensor([[1, 2], [3, 4]])
>>> decoded_output = rvq.decode(codes)
>>> print(decoded_output.shape) # Output will depend on the quantizer setup
######## NOTE The input codes should be properly formatted as per the quantizer’s requirements for decoding to succeed.
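Conceptually, decoding in a residual quantizer looks up one entry per stage and sums them. The sketch below shows that idea in pure Python; `rvq_decode` and the toy codebooks are hypothetical, and the real method works on code tensors and may apply additional projections.

```python
def rvq_decode(codebooks, codes):
    """Sum the selected entry of every stage to rebuild the quantized vector."""
    dim = len(codebooks[0][0])
    quantized = [0.0] * dim
    for codebook, idx in zip(codebooks, codes):
        quantized = [q + c for q, c in zip(quantized, codebook[idx])]
    return quantized


# Hypothetical two-stage codebooks: coarse entries, then fine corrections.
codebooks = [
    [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]],
    [[0.0, 0.0], [0.25, -0.25], [-0.25, 0.25]],
]
decoded = rvq_decode(codebooks, [1, 1])
print(decoded)  # [1.25, 0.75]
```

Passing fewer codes than there are codebooks simply yields a coarser reconstruction, since the later corrections are left out.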
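encode(x: Tensor, sample_rate: int, bandwidth: float | None = None, st: int | None = None) → Tensor
Encode a given input tensor with the specified sample rate at the given bandwidth.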
The RVQ encode method sets the appropriate number of quantizers to use and returns indices for each quantizer.
- Parameters:
- x (torch.Tensor) – Input tensor to be encoded.
- sample_rate (int) – Sample rate of the input tensor.
- bandwidth (Optional[float]) – Target bandwidth. If specified, the number of quantizers will be adjusted based on this value.
- st (Optional[int]) – Starting index for encoding. Defaults to 0.
- Returns: Indices for each quantizer after encoding.
- Return type: torch.Tensor
############### Examples
>>> rvq = ResidualVectorQuantizer()
>>> input_tensor = torch.randn(1, 256, 50)  # Example input; assumed layout: (batch, dimension, frames)
>>> sample_rate = 16000
>>> bandwidth = 128
>>> encoded_indices = rvq.encode(input_tensor, sample_rate, bandwidth)
######## NOTE Ensure that the input tensor has the correct dimensions for encoding.
forward(x: Tensor, sample_rate: int, bandwidth: float | None = None) → QuantizedResult
Residual vector quantization on the given input tensor.
This method performs residual vector quantization on the input tensor x and returns a QuantizedResult containing the quantized representation, associated bandwidth, and any penalty term for the loss. The number of quantizers used is determined based on the specified bandwidth and sample rate.
- Parameters:
- x (torch.Tensor) – Input tensor.
- sample_rate (int) – Sample rate of the input tensor.
- bandwidth (Optional[float]) – Target bandwidth for quantization.
- Returns: A dataclass containing the quantized representation, codes, bandwidth, and an optional penalty term. The attributes are:
- quantized (torch.Tensor): The quantized representation.
- codes (torch.Tensor): The indices of the quantized codes.
- bandwidth (torch.Tensor): The bandwidth in kb/s used per batch item.
- penalty (Optional[torch.Tensor]): An optional penalty term for the loss, if applicable.
- Return type: QuantizedResult
############### Examples
>>> model = ResidualVectorQuantizer()
>>> input_tensor = torch.randn(1, 256, 50)  # assumed layout: (batch, dimension, frames)
>>> sample_rate = 16000
>>> result = model.forward(input_tensor, sample_rate)
>>> print(result.quantized.shape) # Output shape of quantized tensor
>>> print(result.bandwidth) # Output bandwidth
######## NOTE If quantizer_dropout is enabled, the method may also return an additional quantization loss term.
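The shape of the returned container can be pictured with a small stand-in dataclass. This is a hypothetical sketch that only mirrors the fields listed above: the name `QuantizedResultSketch` is invented, `Any` replaces the tensor types, and the sample values are made up for illustration.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class QuantizedResultSketch:
    """Hypothetical stand-in mirroring the documented result fields."""

    quantized: Any                 # quantized representation (a tensor in practice)
    codes: Any                     # indices of the selected codebook entries
    bandwidth: Any                 # bandwidth in kb/s used per batch item
    penalty: Optional[Any] = None  # optional penalty term for the loss


# Made-up values standing in for tensors.
result = QuantizedResultSketch(
    quantized=[[0.5, -0.5]], codes=[[3, 7]], bandwidth=320.0
)

# Callers should treat the penalty as optional when assembling a loss.
loss_terms = [] if result.penalty is None else [result.penalty]
print(result.bandwidth, loss_terms)
```

Accessing results by attribute name, as in the example above, keeps calling code readable and makes the optional penalty explicit.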
get_bandwidth_per_quantizer(sample_rate: int)
Return bandwidth per quantizer for a given input sample rate.
This method calculates the bandwidth required for each quantizer based on the input sample rate. The bandwidth is computed using the formula: bandwidth = log2(bins) * sample_rate / 1000, where ‘bins’ refers to the codebook size.
- Parameters: sample_rate (int) – The sample rate of the input tensor.
- Returns: The bandwidth per quantizer in kilobits per second (kb/s).
- Return type: float
############### Examples
>>> rvq = ResidualVectorQuantizer()
>>> bw = rvq.get_bandwidth_per_quantizer(sample_rate=16000)
>>> print(bw)
160.0 # log2(1024) * 16000 / 1000; the value depends on bins and the sample rate.
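The stated formula is easy to check by hand. The standalone function below is a sketch that reproduces it outside the class; the name `bandwidth_per_quantizer` is hypothetical, and `bins` is passed explicitly rather than read from the instance.

```python
import math


def bandwidth_per_quantizer(bins: int, sample_rate: int) -> float:
    """bandwidth = log2(bins) * sample_rate / 1000, in kb/s."""
    return math.log2(bins) * sample_rate / 1000


# With the default codebook size of 1024, each code carries 10 bits.
bw_16k = bandwidth_per_quantizer(bins=1024, sample_rate=16000)
bw_44k = bandwidth_per_quantizer(bins=1024, sample_rate=44100)
print(bw_16k)  # 160.0
print(bw_44k)  # 441.0
```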
get_num_quantizers_for_bandwidth(sample_rate: int, bandwidth: float | None = None) → int
Return the number of quantizers (n_q) based on specified target bandwidth.
This method calculates the number of residual vector quantizers that can be utilized given a target bandwidth. It ensures that the number of quantizers does not exceed the available bandwidth divided by the bandwidth per quantizer.
- Parameters:
- sample_rate (int) – The sample rate of the input data.
- bandwidth (Optional[float]) – The target bandwidth in kb/s. If None or less than or equal to zero, the maximum number of quantizers (n_q) is returned.
- Returns: The number of quantizers to use based on the specified bandwidth.
- Return type: int
############### Examples
>>> rvq = ResidualVectorQuantizer()
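>>> # Each quantizer costs log2(1024) * 44100 / 1000 = 441 kb/s here, so a
>>> # 64 kb/s target is below the cost of a single quantizer.
>>> n_q_low = rvq.get_num_quantizers_for_bandwidth(sample_rate=44100, bandwidth=64)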
>>> rvq.get_num_quantizers_for_bandwidth(sample_rate=44100, bandwidth=None)
8
######## NOTE The bandwidth must be a positive value to calculate the number of quantizers. If the bandwidth is invalid, the method defaults to the maximum number of quantizers defined during initialization.
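The selection logic described above can be sketched as a standalone function. Two details are assumptions rather than documented behavior: that the count is rounded down, and that it is clamped to the range from 1 to n_q. The function name `num_quantizers_for_bandwidth` is hypothetical.

```python
import math


def num_quantizers_for_bandwidth(n_q, bins, sample_rate, bandwidth=None):
    """Pick how many quantizers fit into the target bandwidth (kb/s)."""
    bw_per_q = math.log2(bins) * sample_rate / 1000
    if bandwidth is None or bandwidth <= 0:
        return n_q  # no valid target: use every quantizer
    fitting = math.floor(bandwidth / bw_per_q)
    # Assumption: keep at least one quantizer and never exceed n_q.
    return max(1, min(n_q, fitting))


full = num_quantizers_for_bandwidth(8, 1024, 16000)           # no target
three = num_quantizers_for_bandwidth(8, 1024, 16000, 480.0)   # 480 / 160
one = num_quantizers_for_bandwidth(8, 1024, 16000, 64.0)      # below one stage
print(full, three, one)  # 8 3 1
```

Because the per-quantizer cost grows with the sample rate, the same target bandwidth admits fewer quantizers at higher rates.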