espnet2.gan_codec.shared.quantizer.modules.core_vq.sample_vectors
espnet2.gan_codec.shared.quantizer.modules.core_vq.sample_vectors(samples, num: int)
Samples a specified number of vectors from the given samples.
This function draws num rows at random from the input tensor samples, which holds N rows. If N is greater than or equal to num, the rows are drawn without replacement, so no row appears twice in the result. If N is less than num, the rows are drawn with replacement, so at least one row is repeated.
- Parameters:
  - samples (Tensor) – A tensor of shape (N, D), where N is the number of samples and D is the dimensionality of each sample.
  - num (int) – The number of vectors to draw from samples.
- Returns: A tensor of shape (num, D) containing the sampled vectors.
- Return type: Tensor
Examples
>>> import torch
>>> from espnet2.gan_codec.shared.quantizer.modules.core_vq import sample_vectors
>>> samples = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
>>> sampled = sample_vectors(samples, 2)  # N >= num: rows are distinct
>>> print(sampled)  # output varies between runs
tensor([[3., 4.],
        [1., 2.]])
>>> sampled = sample_vectors(samples, 5)  # N < num: rows repeat
>>> print(sampled)  # output varies between runs
tensor([[1., 2.],
        [1., 2.],
        [5., 6.],
        [3., 4.],
        [5., 6.]])
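The two sampling branches described above can be illustrated with a minimal sketch. The function name `sample_vectors_sketch` is hypothetical, and the use of `torch.randperm` and `torch.randint` is an assumption about one reasonable way to get the documented behaviour; it is not presented as the library's actual source.

```python
import torch


def sample_vectors_sketch(samples: torch.Tensor, num: int) -> torch.Tensor:
    """Hypothetical re-implementation of the documented behaviour."""
    num_samples = samples.shape[0]
    if num_samples >= num:
        # Enough rows: keep the first `num` entries of a random permutation,
        # so no row index is selected twice.
        indices = torch.randperm(num_samples, device=samples.device)[:num]
    else:
        # Too few rows: draw indices uniformly at random, allowing repeats.
        indices = torch.randint(0, num_samples, (num,), device=samples.device)
    return samples[indices]


samples = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
without_repl = sample_vectors_sketch(samples, 2)  # distinct rows
with_repl = sample_vectors_sketch(samples, 5)     # some rows repeat
print(without_repl.shape)  # torch.Size([2, 2])
print(with_repl.shape)     # torch.Size([5, 2])
```

In both branches the result has exactly num rows, which is why callers can rely on the (num, D) output shape regardless of how many samples were supplied.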