espnet2.legacy.nets.pytorch_backend.wavenet.WaveNet
class espnet2.legacy.nets.pytorch_backend.wavenet.WaveNet(n_quantize=256, n_aux=28, n_resch=512, n_skipch=256, dilation_depth=10, dilation_repeat=3, kernel_size=2, upsampling_factor=0)
Bases: Module
Conditional WaveNet.
- Parameters:
  - n_quantize (int) – Number of quantization levels.
  - n_aux (int) – Number of auxiliary feature dimensions.
  - n_resch (int) – Number of filter channels in each residual block.
  - n_skipch (int) – Number of filter channels in each skip connection.
  - dilation_depth (int) – Depth of each dilation stack (e.g., if set to 10, the maximum dilation is 2^(10-1) = 512).
  - dilation_repeat (int) – Number of times the dilation stack is repeated.
  - kernel_size (int) – Filter size of the dilated causal convolution.
  - upsampling_factor (int) – Upsampling factor.
Initialize WaveNet class.
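The interaction between `dilation_depth`, `dilation_repeat`, and `kernel_size` can be illustrated by computing the dilation schedule and the resulting receptive field. This is a minimal sketch that assumes the usual formula for a stack of dilated causal convolutions with dilations doubling at each layer; the helper names `wavenet_dilations` and `receptive_field` are hypothetical and not part of the ESPnet API, and the class itself may account for additional layers differently.

```python
def wavenet_dilations(dilation_depth=10, dilation_repeat=3):
    """Return per-layer dilation factors: 1, 2, ..., 2^(depth-1), repeated."""
    return [2 ** i for i in range(dilation_depth)] * dilation_repeat


def receptive_field(dilation_depth=10, dilation_repeat=3, kernel_size=2):
    """Number of input samples (current one included) seen by the dilated stack."""
    dilations = wavenet_dilations(dilation_depth, dilation_repeat)
    return (kernel_size - 1) * sum(dilations) + 1


dilations = wavenet_dilations()
print(max(dilations))     # 512
print(len(dilations))     # 30
print(receptive_field())  # 3070
```

With the default arguments, each of the 3 repeats contributes dilations summing to 1023, so the stack covers 3 × 1023 + 1 = 3070 samples under this assumption.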
forward(x, h)
Calculate forward propagation.
- Parameters:
  - x (LongTensor) – Quantized input waveform tensor with the shape (B, T).
  - h (Tensor) – Auxiliary feature tensor with the shape (B, n_aux, T).
- Returns: Logits with the shape (B, T, n_quantize).
- Return type: Tensor
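The input `x` holds integer class indices in the range [0, n_quantize - 1]. This page does not say how the waveform is quantized; the sketch below assumes mu-law companding, which is a common choice for WaveNet-style models. The functions `mu_law_encode` and `mu_law_decode` are illustrative helpers written for this example, not the library's own routines.

```python
import math


def mu_law_encode(sample, n_quantize=256):
    """Map a float sample in [-1, 1] to an integer class in [0, n_quantize - 1]."""
    mu = n_quantize - 1
    sample = max(-1.0, min(1.0, sample))  # clip to the valid range
    # Logarithmic compression that keeps the sign of the sample.
    compressed = math.copysign(
        math.log1p(mu * abs(sample)) / math.log1p(mu), sample
    )
    # Shift from [-1, 1] to [0, mu] and round to the nearest integer.
    return int(math.floor((compressed + 1.0) / 2.0 * mu + 0.5))


def mu_law_decode(index, n_quantize=256):
    """Approximately invert mu_law_encode, returning a float in [-1, 1]."""
    mu = n_quantize - 1
    compressed = 2.0 * index / mu - 1.0
    return math.copysign(
        math.expm1(abs(compressed) * math.log1p(mu)) / mu, compressed
    )


print([mu_law_encode(s) for s in (-1.0, 0.0, 1.0)])  # [0, 128, 255]
```

A sequence of such indices, batched into a tensor of shape (B, T), matches the integer input that `forward` expects under this assumption.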
generate(x, h, n_samples, interval=None, mode='sampling')
Generate a waveform with a fast generation algorithm.
Generation is based on the Fast WaveNet Generation Algorithm.
- Parameters:
  - x (LongTensor) – Initial waveform tensor with the shape (T,).
  - h (Tensor) – Auxiliary feature tensor with the shape (n_samples + T, n_aux).
  - n_samples (int) – Number of samples to be generated.
  - interval (int, optional) – Log interval.
  - mode (str, optional) – "sampling" or "argmax".
- Returns: Generated quantized waveform with the shape (n_samples,).
- Return type: ndarray
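The two `mode` values differ in how the next sample is chosen from the logits at each step. The sketch below shows one plausible reading: "argmax" takes the most likely class, while "sampling" draws from the softmax distribution over classes. The function `pick_next_sample` is a hypothetical helper for illustration only and is not how the class is necessarily implemented.

```python
import math
import random


def pick_next_sample(logits, mode="sampling", rng=random):
    """Choose the next quantized sample index from one step of logits."""
    if mode == "argmax":
        # Deterministic: index of the largest logit.
        return max(range(len(logits)), key=lambda i: logits[i])
    if mode == "sampling":
        # Stochastic: draw from softmax(logits); subtract the max for stability.
        peak = max(logits)
        weights = [math.exp(v - peak) for v in logits]
        return rng.choices(range(len(logits)), weights=weights, k=1)[0]
    raise ValueError(f"unknown mode: {mode}")


step_logits = [0.1, 2.0, -1.0]
print(pick_next_sample(step_logits, mode="argmax"))  # 1
```

Under this reading, "argmax" always returns the same waveform for the same inputs, whereas "sampling" varies from run to run unless the random generator is seeded.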
