espnet2.asr.encoder.avhubert_encoder.ResEncoder
class espnet2.asr.encoder.avhubert_encoder.ResEncoder(relu_type, weights)
Bases: Module
3D ResNet-based Encoder for audio-visual tasks.
This class implements a convolutional network that encodes video clips into one feature vector per frame. A 3D convolutional frontend first extracts spatio-temporal features from the whole clip. The time axis is then folded into the batch axis so that a 2D ResNet trunk, built from residual blocks, can process every frame independently.
frontend_nout
Number of output channels for the frontend.
- Type: int
backend_out
Number of output channels for the trunk.
- Type: int
frontend3D
Sequential model for the frontend processing.
- Type: nn.Sequential
trunk
2D ResNet applied to each frame after the time axis has been folded into the batch axis.
- Type: ResNet
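The frontend/trunk split described above can be illustrated with toy layers. This is a minimal sketch, not the ESPnet implementation: the class names `TinyResEncoder` and `TinyResidualBlock`, the channel counts, and the kernel sizes are hypothetical, and a single residual block plus one strided convolution stand in for the full ResNet trunk. It only shows how a 3D frontend, a time-to-batch fold, a per-frame 2D trunk, and a final unfold yield a (B, D, T) output.

```python
import torch
import torch.nn as nn


class TinyResidualBlock(nn.Module):
    """Hypothetical 2D residual block: out = relu(x + conv2(relu(conv1(x))))."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(x + self.conv2(self.relu(self.conv1(x))))


class TinyResEncoder(nn.Module):
    """Hypothetical, scaled-down stand-in for ResEncoder (not the ESPnet code)."""

    def __init__(self, frontend_nout: int = 8, backend_out: int = 16):
        super().__init__()
        # 3D frontend: temporal kernel 5, padding 2, stride 1 keeps T unchanged.
        self.frontend3D = nn.Sequential(
            nn.Conv3d(1, frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2),
                      padding=(2, 3, 3), bias=False),
            nn.BatchNorm3d(frontend_nout),
            nn.ReLU(),
        )
        # 2D trunk applied to every frame independently.
        self.trunk = nn.Sequential(
            TinyResidualBlock(frontend_nout),
            nn.Conv2d(frontend_nout, backend_out, kernel_size=3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d(1),  # collapse the spatial dims to 1x1
            nn.Flatten(),  # (B*T, backend_out)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b = x.size(0)
        x = self.frontend3D(x)  # (B, C', T, H', W')
        t = x.size(2)
        # Fold time into batch: (B, C', T, H', W') -> (B*T, C', H', W')
        x = x.transpose(1, 2).reshape(b * t, x.size(1), x.size(3), x.size(4))
        x = self.trunk(x)  # (B*T, D)
        # Unfold, then put features before time: (B, T, D) -> (B, D, T)
        return x.view(b, t, -1).transpose(1, 2).contiguous()


model = TinyResEncoder()
video = torch.randn(2, 1, 6, 32, 32)  # (B, C, T, H, W)
print(model(video).shape)  # torch.Size([2, 16, 6])
```

Because the trunk ends in adaptive average pooling, the spatial size only affects intermediate shapes, while T passes through unchanged.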
Parameters:
- relu_type (str) – Type of ReLU activation to use (‘relu’ or ‘prelu’).
- weights (Optional[str]) – Path to pre-trained weights for the model, or None to skip loading.
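The `relu_type` string selects the activation used in the network. A minimal sketch of such a switch, assuming only the two documented values; the helper name `make_activation`, the per-channel PReLU slope (`num_parameters=channels`), and raising `ValueError` for other strings are illustrative assumptions, not ESPnet behavior.

```python
import torch.nn as nn


def make_activation(relu_type: str, channels: int) -> nn.Module:
    """Hypothetical helper: map a relu_type string to an activation module."""
    if relu_type == "relu":
        return nn.ReLU()
    if relu_type == "prelu":
        # One learnable negative slope per channel (assumed configuration).
        return nn.PReLU(num_parameters=channels)
    raise ValueError(f"unsupported relu_type: {relu_type!r}")


act = make_activation("prelu", channels=64)
print(sum(p.numel() for p in act.parameters()))  # 64
```

ReLU has no parameters, whereas PReLU adds a learnable slope for negative inputs.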
######### Examples
>>> model = ResEncoder(relu_type='relu', weights=None)
>>> input_tensor = torch.randn(8, 1, 10, 112, 112) # (B, C, T, H, W)
>>> output = model(input_tensor)
>>> output.shape
torch.Size([8, 512, 10]) # (B, D, T)
####### NOTE The input tensor must have shape (B, C, T, H, W), where B is the batch size, C is the number of channels (1 in the example above, i.e. single-channel frames), T is the number of frames, and H and W are the frame height and width.
- Raises:RuntimeError – If the pre-trained weights cannot be loaded, or if the input shape is incompatible with the network during the forward pass.
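Video decoders often return frames channel-last, as (B, T, H, W, C). A minimal sketch of converting such a batch into the (B, C, T, H, W) layout, assuming a single-channel input as in the class example. The helper name `to_encoder_layout` is hypothetical, and reducing RGB to one channel with a plain mean is a simplification, not a standard luminance formula.

```python
import torch


def to_encoder_layout(frames: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, T, H, W, C) frames -> (B, 1, T, H, W) float tensor."""
    if frames.dim() != 5:
        raise ValueError(f"expected 5 dims (B, T, H, W, C), got {frames.dim()}")
    x = frames.float()
    # Average over the color channels (simplified grayscale), keeping the axis.
    x = x.mean(dim=-1, keepdim=True)  # (B, T, H, W, 1)
    # Move the channel axis directly after the batch axis.
    return x.permute(0, 4, 1, 2, 3).contiguous()  # (B, 1, T, H, W)


rgb = torch.randint(0, 256, (8, 10, 112, 112, 3), dtype=torch.uint8)
print(to_encoder_layout(rgb).shape)  # torch.Size([8, 1, 10, 112, 112])
```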
forward(x)
Forward pass through the ResEncoder.
This method applies the 3D frontend to the input clip, folds the time axis into the batch axis with threeD_to_2D_tensor, passes every frame through the 2D ResNet trunk, and restores the time axis. The number of frames T is preserved.
- Parameters:x (torch.Tensor) – Input video tensor of shape (B, C, T, H, W).
- Returns: Encoded features of shape (B, D, T), where D is backend_out.
- Return type: torch.Tensor
######### Examples
>>> model = ResEncoder(relu_type='relu', weights=None)
>>> video_input = torch.randn(2, 1, 10, 112, 112) # (B, C, T, H, W)
>>> output = model(video_input)
>>> output.shape
torch.Size([2, 512, 10]) # (B, D, T)
####### NOTE Pad all clips in a batch to the same number of frames T before stacking them, since this method operates on a single dense (B, C, T, H, W) tensor.
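Clips of different lengths must share one T before they can be stacked into a single (B, C, T, H, W) tensor. A minimal sketch using `torch.nn.utils.rnn.pad_sequence`, which zero-pads along the first (time) dimension of each clip. The helper name `batch_clips` is hypothetical, and returning a separate lengths tensor is an assumption about how a caller might track valid frames, since `forward(x)` itself takes no lengths.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def batch_clips(clips: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: list of (T_i, H, W) clips -> ((B, 1, T_max, H, W), lengths)."""
    lengths = torch.tensor([c.size(0) for c in clips])
    # Zero-pad every clip to the longest T along dim 0.
    padded = pad_sequence(clips, batch_first=True)  # (B, T_max, H, W)
    return padded.unsqueeze(1), lengths  # insert the channel axis, C = 1


clips = [torch.randn(10, 112, 112), torch.randn(7, 112, 112)]
video, lengths = batch_clips(clips)
print(video.shape, lengths.tolist())  # torch.Size([2, 1, 10, 112, 112]) [10, 7]
```

The encoder still produces features for the padded frames, so downstream code has to use the lengths to ignore them.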
threeD_to_2D_tensor(x)
Reshape a 5-D tensor for 3D convolutions into a 4-D tensor for 2D convolutions.
This method takes a 5-dimensional tensor (batch, channels, time, height, width) and reshapes it into a 4-dimensional tensor (batch*time, channels, height, width) in which each entry along the first axis is one frame, ordered batch-major (entry b*time + t holds frame t of sample b). This lets a 2D convolutional network process the frames independently after the 3D frontend.
- Parameters:x (torch.Tensor) – A tensor of shape (n_batch, n_channels, s_time, sx, sy).
- Returns: A reshaped tensor of shape (n_batch * s_time, n_channels, sx, sy).
- Return type: torch.Tensor
######### Examples
>>> import torch
>>> model = ResEncoder(relu_type='relu', weights=None)
>>> x = torch.randn(2, 3, 4, 5, 6) # (n_batch, n_channels, s_time, sx, sy)
>>> reshaped_x = model.threeD_to_2D_tensor(x)
>>> reshaped_x.shape
torch.Size([8, 3, 5, 6]) # (2 * 4, 3, 5, 6)
####### NOTE The input tensor must have 5 dimensions; otherwise, an error will occur.
- Raises:ValueError – If the input tensor does not have 5 dimensions.
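The reshape described above amounts to a transpose followed by a reshape, and the inverse (needed after the trunk) is a reshape back to (B, T, ...). A minimal standalone sketch with hypothetical function names `fold_time` and `unfold_time`, mirroring the documented behavior rather than quoting the ESPnet source:

```python
import torch


def fold_time(x: torch.Tensor) -> torch.Tensor:
    """(B, C, T, H, W) -> (B*T, C, H, W), frames ordered batch-major."""
    if x.dim() != 5:
        raise ValueError(f"expected a 5-D tensor, got {x.dim()} dims")
    b, c, t, h, w = x.shape
    # Move time next to batch, then merge the two axes.
    return x.transpose(1, 2).reshape(b * t, c, h, w)


def unfold_time(y: torch.Tensor, batch: int) -> torch.Tensor:
    """(B*T, ...) -> (B, T, ...), undoing the batch/time merge."""
    return y.reshape(batch, y.size(0) // batch, *y.shape[1:])


x = torch.randn(2, 3, 4, 5, 6)
y = fold_time(x)
print(y.shape)  # torch.Size([8, 3, 5, 6])
# Entry b*T + t of y is frame t of sample b.
print(torch.equal(y[1 * 4 + 2], x[1, :, 2]))  # True
```

Transposing before reshaping is what keeps each frame's channels together; reshaping the original (B, C, T, H, W) tensor directly would mix channels from different frames.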