espnet2.gan_codec.shared.encoder.seanet_2d.ReshapeModule
class espnet2.gan_codec.shared.encoder.seanet_2d.ReshapeModule(dim=2)
Bases: Module
Module that reshapes a tensor by squeezing out a single specified dimension.
This module removes one size-1 dimension from the input tensor, reducing its rank by one. This is useful in neural network architectures where a singleton axis must be dropped before the next layer.
dim
The dimension to squeeze from the input tensor.
Type: int
Parameters: dim (int) – The dimension to squeeze from the input tensor. Defaults to 2.
Returns: The reshaped tensor with the specified dimension removed.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> reshape_module = ReshapeModule(dim=2)
>>> input_tensor = torch.rand(2, 3, 1, 4)  # Shape: (2, 3, 1, 4)
>>> output_tensor = reshape_module(input_tensor)
>>> output_tensor.shape  # Dimension 2 (size 1) has been squeezed
torch.Size([2, 3, 4])
NOTE
The specified dimension must be a valid index for the input tensor; an out-of-range index raises an IndexError. If the dimension exists but its size is not 1, the tensor is returned unchanged.
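The squeezing rule can be illustrated on shapes alone: a dimension is removed only when its size is 1, a dimension of any other size leaves the shape unchanged, and an out-of-range index raises IndexError. The sketch below is a pure-Python illustration that assumes the module follows `torch.squeeze`'s single-dimension behaviour; `squeezed_shape` is a hypothetical helper, not part of ESPnet, and it assumes the shape has at least one dimension.

```python
from typing import Sequence, Tuple


def squeezed_shape(shape: Sequence[int], dim: int) -> Tuple[int, ...]:
    """Hypothetical helper: shape left after squeezing axis `dim` (assumes torch.squeeze-style semantics)."""
    ndim = len(shape)
    # Reject out-of-range indices; this sketch assumes at least one dimension.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} is out of range for a {ndim}-d shape")
    dim %= ndim  # normalize negative indices, e.g. -1 -> ndim - 1
    if shape[dim] != 1:
        return tuple(shape)  # size is not 1: nothing is removed
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeezed_shape((2, 3, 1, 4), 2))  # (2, 3, 4)
print(squeezed_shape((2, 3, 4), 2))     # (2, 3, 4), unchanged because the size is 4
```

With the default `dim=2`, a `(batch, channels, 1, time)` shape therefore becomes `(batch, channels, time)`, while a shape whose axis 2 is larger than 1 passes through as-is.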
forward(x)
Squeeze the configured dimension from the input tensor.
This method removes a dimension of size 1 from the input tensor along the axis given by `dim`. It is useful wherever redundant singleton dimensions must be eliminated, for example when preparing inputs for further processing or outputs for loss calculations.
Parameters: x (torch.Tensor) – The input tensor. Its dimension `dim` (set at construction, default 2) is squeezed.
Returns: The input tensor with the specified dimension squeezed.
Return type: torch.Tensor
####### Examples
>>> import torch
>>> reshape_module = ReshapeModule(dim=1)
>>> input_tensor = torch.tensor([[1], [2], [3]])  # shape (3, 1)
>>> output_tensor = reshape_module(input_tensor)  # shape (3,)
>>> print(output_tensor)
tensor([1, 2, 3])
NOTE
Only a dimension of size 1 is removed. If the specified dimension has a size other than 1, the output tensor has the same shape as the input tensor.
Raises: IndexError – If the specified dimension is out of bounds for the input tensor.
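To see the same squeeze on concrete values without depending on PyTorch, the sketch below applies it to a rectangular nested Python list. `squeeze_nested` is a hypothetical helper written for illustration only; it is not ESPnet's implementation, and it assumes rectangular input and non-negative `dim` values.

```python
from typing import Any


def squeeze_nested(data: Any, dim: int) -> Any:
    """Hypothetical helper: drop axis `dim` of a rectangular nested list if it has length 1."""
    if dim < 0:
        raise ValueError("this sketch only handles non-negative dims")
    if not isinstance(data, list):
        # Ran out of nesting levels before reaching `dim`.
        raise IndexError("dim is out of range for this nested list")
    if dim == 0:
        # Axis 0 is this list itself: unwrap it only if it holds exactly one element.
        return data[0] if len(data) == 1 else data
    # Otherwise descend one level, so the target axis index shrinks by one.
    return [squeeze_nested(item, dim - 1) for item in data]


print(squeeze_nested([[1], [2], [3]], 1))        # [1, 2, 3]
print(squeeze_nested([[[1]], [[2]], [[3]]], 1))  # [[1], [2], [3]]
```

The second call shows that squeezing axis 1 of a `(3, 1, 1)` input leaves a `(3, 1)` result: only one singleton axis is removed per call, so dropping both would require squeezing twice.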