espnet2.uasr.generator.conv_generator.TransposeLast
class espnet2.uasr.generator.conv_generator.TransposeLast(deconstruct_idx=None)
Bases: Module
Transpose the last two dimensions of the input tensor.
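This module is useful around layers that expect a different dimension order, for example a torch.nn.Conv1d applied to channel-last features. If deconstruct_idx is provided, the input is first indexed as x[deconstruct_idx] (for example, to pick one element of a tuple returned by the previous layer), and the result is then transposed.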
deconstruct_idx
The index used to select part of the input (x[deconstruct_idx]) before transposing. If None, the whole input is transposed.
Type: Optional[int]
Parameters: deconstruct_idx (Optional[int]) – The index used to select a specific part of the input before transposing. Defaults to None.
Returns: The input tensor (or, when deconstruct_idx is set, the selected element x[deconstruct_idx]) with its last two dimensions transposed.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> from espnet2.uasr.generator.conv_generator import TransposeLast
>>> transpose_last = TransposeLast()
>>> x = torch.randn(2, 3, 4)
>>> output = transpose_last(x)
>>> output.shape
torch.Size([2, 4, 3])
>>> transpose_last_deconstruct = TransposeLast(deconstruct_idx=1)
>>> x = torch.randn(2, 3, 4)
>>> output = transpose_last_deconstruct(x)  # selects x[1], shape (3, 4), then transposes
>>> output.shape
torch.Size([4, 3])
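The behavior described above can be sketched in a few lines. This is a hypothetical re-implementation (`TransposeLastSketch` is not part of ESPnet), assuming forward indexes the input with deconstruct_idx when it is set and then calls `transpose(-2, -1)`. The LSTM case illustrates one plausible use of deconstruct_idx: picking the output tensor out of the `(output, (h_n, c_n))` tuple that `torch.nn.LSTM` returns.

```python
from typing import Optional, Sequence, Union

import torch


class TransposeLastSketch(torch.nn.Module):
    """Hypothetical re-implementation of TransposeLast, for illustration only."""

    def __init__(self, deconstruct_idx: Optional[int] = None) -> None:
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x: Union[torch.Tensor, Sequence[torch.Tensor]]) -> torch.Tensor:
        # Optionally pick one element first: x[idx] of a tensor or of a tuple.
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        # Swap the last two dimensions, e.g. (B, T, C) -> (B, C, T).
        return x.transpose(-2, -1)


# Plain tensor input: only the last two dimensions are swapped.
x = torch.randn(2, 3, 4)
print(TransposeLastSketch()(x).shape)  # torch.Size([2, 4, 3])

# Tuple input: nn.LSTM returns (output, (h_n, c_n)); index 0 selects output.
lstm = torch.nn.LSTM(input_size=4, hidden_size=5, batch_first=True)
out = TransposeLastSketch(deconstruct_idx=0)(lstm(x))
print(out.shape)  # torch.Size([2, 5, 3])
```

Because indexing happens before the transpose, a (2, 3, 4) tensor with deconstruct_idx=1 yields a (4, 3) result, matching the corrected doctest output.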
ConvGenerator.forward(feats, text, feats_padding_mask)
Forward pass of the convolutional generator (ConvGenerator).
This method applies optional batch normalization to the non-padded frames, optional residual connections, and a 1-D convolution that can shorten the sequence, producing the generated sample and a padding mask matched to its new length. If text is provided, it also returns a one-hot encoding of the text indices as the real sample.
- Parameters:
- feats (torch.Tensor) – Input feature tensor of shape (B, L, C), where B is the batch size, L is the sequence length, and C is the number of channels (input_dim).
- text (Optional[torch.Tensor]) – Tensor of token indices. If provided, it is one-hot encoded into real_sample.
- feats_padding_mask (torch.Tensor) – Padding mask of shape (B, L); True marks padded frames and False marks valid frames.
- Returns: Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor]: A tuple containing:
- generated_sample (torch.Tensor): Output tensor of shape (B, new_length, output_dim).
- real_sample (Optional[torch.Tensor]): One-hot encoding of text over output_dim classes, shape (*text.shape, output_dim), if text is provided; otherwise None.
- inter_x (Optional[torch.Tensor]): Intermediate tensor from the residual branch if residual connections are enabled; otherwise None.
- generated_sample_padding_mask (torch.Tensor): Padding mask aligned with generated_sample, shape (B, new_length).
- Raises: AssertionError – If text is provided but contains no non-zero elements.
####### Examples
>>> import torch
>>> from espnet2.uasr.generator.conv_generator import ConvGenerator
>>> generator = ConvGenerator(input_dim=256, output_dim=512)
>>> feats = torch.randn(10, 50, 256)  # (B, L, C)
>>> text = torch.randint(1, 512, (10, 8))  # non-zero indices below output_dim
>>> feats_padding_mask = torch.zeros(10, 50, dtype=torch.bool)  # no padding
>>> generated_sample, real_sample, inter_x, mask = generator(
...     feats, text, feats_padding_mask
... )
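The Returns list describes real_sample as a one-hot encoding of text. A minimal sketch of that step, assuming indices lie in [0, output_dim) and using a hypothetical helper name (`one_hot_real_sample` is not ESPnet code), can be written with `Tensor.scatter_`; the assertion mirrors the documented AssertionError.

```python
import torch


def one_hot_real_sample(text: torch.Tensor, output_dim: int) -> torch.Tensor:
    """Hypothetical helper: one-hot encode token indices to (*text.shape, output_dim)."""
    # Mirrors the documented precondition: at least one non-zero index.
    assert torch.count_nonzero(text) > 0
    flat = torch.zeros(text.numel(), output_dim)
    # For each row i, write a 1 into column text[i].
    flat.scatter_(1, text.reshape(-1, 1).long(), 1.0)
    return flat.view(*text.shape, output_dim)


text = torch.tensor([[1, 0], [0, 1]])
real = one_hot_real_sample(text, output_dim=3)
print(real.shape)  # torch.Size([2, 2, 3])
print(real[0, 0])  # tensor([0., 1., 0.])
```

The result is the same as `torch.nn.functional.one_hot(text, output_dim).float()`; the scatter form simply makes the flatten, fill, and reshape steps explicit.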
NOTE
All inputs are processed as a batch: feats and feats_padding_mask must agree in batch size and sequence length, and the channel dimension of feats must equal input_dim. Mismatched shapes raise runtime errors.
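The forward pipeline described for ConvGenerator can be illustrated with a simplified function. This is a hypothetical sketch (`sketch_generator_step` is not ESPnet code) that assumes channel-last features of shape (B, T, C) and a padding mask in which True marks padded frames; it omits dropout and the residual branch. It also shows where the TransposeLast pattern fits: `torch.nn.Conv1d` expects channel-first input, so the features are transposed before and after the convolution.

```python
import torch
import torch.nn as nn


def sketch_generator_step(feats, padding_mask, bn, conv, stride):
    """Hypothetical simplification: masked batch norm, Conv1d, mask subsampling.

    feats: (B, T, C) channel-last features.
    padding_mask: (B, T) bool, True marks padded frames.
    """
    # Normalize only the valid frames; padded frames are copied through unchanged.
    normed = feats.clone()
    normed[~padding_mask] = bn(feats[~padding_mask])  # (N_valid, C) -> (N_valid, C)
    # Conv1d expects (B, C, T): transpose in and back out, as TransposeLast does.
    out = conv(normed.transpose(-2, -1)).transpose(-2, -1)  # (B, T', D)
    # Keep every stride-th mask entry so the mask follows the new frame rate.
    out_mask = padding_mask[:, ::stride]
    # With kernel_size=3 and padding=1, both lengths equal ceil(T / stride).
    assert out_mask.size(1) == out.size(1)
    return out, out_mask


torch.manual_seed(0)
B, T, C, D, stride = 2, 50, 8, 4, 9
feats = torch.randn(B, T, C)
padding_mask = torch.zeros(B, T, dtype=torch.bool)
padding_mask[1, 40:] = True  # last 10 frames of the second sequence are padding
bn = nn.BatchNorm1d(C)
conv = nn.Conv1d(C, D, kernel_size=3, stride=stride, padding=1)
out, out_mask = sketch_generator_step(feats, padding_mask, bn, conv, stride)
print(out.shape, out_mask.shape)  # torch.Size([2, 6, 4]) torch.Size([2, 6])
```

The length check holds only for this kernel and padding choice; other convolution settings can make the subsampled mask and the convolution output differ by a frame or more, which would then need explicit trimming or padding.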