espnet2.enh.layers.adapt_layers.ConcatAdaptLayer
class espnet2.enh.layers.adapt_layers.ConcatAdaptLayer(indim, enrolldim, ninputs=1)
Bases: Module
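ConcatAdaptLayer is a PyTorch module that conditions the activations of a main network on an enrollment embedding. It concatenates the embedding to the activations at every time frame and uses a linear layer to project the result back to indim dimensions. If you set ninputs > 1 and pass tuples or lists, one layer can adapt several activations at once, e.g. both the normal and the skip connection.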
ninputs
The number of input tensors to adapt.
- Type: int
transform
One linear layer per input. Each layer maps the concatenated (indim + enrolldim)-dimensional features back to indim dimensions.
- Type: nn.ModuleList
Parameters:
- indim (int) – The dimensionality of the input activations.
- enrolldim (int) – The dimensionality of the enrollment embeddings.
- ninputs (int, optional) – The number of inputs to adapt. Defaults to 1.
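The NumPy sketch below shows the computation for a single input. It is a minimal illustration, not the ESPnet code. Several parts are assumptions: the function name concat_adapt, and the random weight and bias arrays, which are hypothetical stand-ins for one nn.Linear in transform. The sketch also assumes activations are laid out as (batch, indim, time) and embeddings as (batch, enrolldim).

```python
import numpy as np


def concat_adapt(x, e, weight, bias):
    """Concatenation-based adaptation of one activation tensor (sketch).

    x:      (batch, indim, time) main-network activations
    e:      (batch, enrolldim) enrollment embedding
    weight: (indim, indim + enrolldim), bias: (indim,)
            hypothetical stand-ins for one linear layer in `transform`
    """
    batch, _, time = x.shape
    # Repeat the utterance-level embedding at every frame: (B, enrolldim, T)
    e_tiled = np.broadcast_to(e[:, :, None], (batch, e.shape[1], time))
    # Stack along the feature axis: (B, indim + enrolldim, T)
    z = np.concatenate([x, e_tiled], axis=1)
    # Apply the same linear map to every frame: (B, T, indim), then back to (B, indim, T)
    y = z.transpose(0, 2, 1) @ weight.T + bias
    return y.transpose(0, 2, 1)


rng = np.random.default_rng(0)
indim, enrolldim = 128, 64
x = rng.standard_normal((10, indim, 50))
e = rng.standard_normal((10, enrolldim))
weight = rng.standard_normal((indim, indim + enrolldim)) * 0.01  # hypothetical weights
bias = np.zeros(indim)

y = concat_adapt(x, e, weight, bias)
print(y.shape)  # (10, 128, 50)
```

The embedding is identical at every frame. Split weight into its first indim columns (Wx) and its remaining enrolldim columns (We). Then each output frame equals Wx·x + We·e + bias, so the enrollment adds an utterance-level bias to a linear transform of the activations.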
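forward(main, enroll)
Adapts the main activations by concatenating the enrollment embedding and applying the corresponding linear transformation.
- Parameters:
- main (torch.Tensor or tuple or list) – Activations of the main network, each of shape (batch, indim, time). Pass a tuple or list of ninputs tensors to adapt several activations at once, e.g. the normal and the skip connection.
- enroll (torch.Tensor or tuple or list) – Enrollment embeddings, each of shape (batch, enrolldim), in a container of the same type and length as main.
- Returns: The adapted activations, each of shape (batch, indim, time), in the same container type as main.
- Return type: torch.Tensor or tuple or list
- Raises:
- AssertionError – If main and enroll are not of the same type, or if their lengths do not equal ninputs.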
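The container handling described for forward accepts a single tensor, a tuple, or a list and returns the same type. It can be sketched in plain Python. The helper adapt_inputs and the stand-in add_scaled are hypothetical names used only for illustration. Plain floats stand in for tensors, and the per-index callback stands in for applying self.transform[i] to the concatenated input.

```python
def adapt_inputs(main, enroll, ninputs, adapt_one):
    """Apply `adapt_one` pairwise and return the same container type as `main` (sketch)."""
    # Both arguments must share a type: both tensors, both tuples, or both lists.
    assert type(main) == type(enroll), "main and enroll must have the same type"
    orig_type = type(main)
    # Treat a single tensor as a one-element tuple.
    if not isinstance(main, (tuple, list)):
        main, enroll = (main,), (enroll,)
    assert len(main) == len(enroll) == ninputs, "expected ninputs items"
    # The index i selects which per-input transform to use.
    out = tuple(adapt_one(i, x, e) for i, (x, e) in enumerate(zip(main, enroll)))
    # Restore the caller's container type.
    if orig_type is tuple:
        return out
    if orig_type is list:
        return list(out)
    return out[0]


def add_scaled(i, x, e):
    # Hypothetical stand-in for self.transform[i] applied to concat(x, e).
    return x + (i + 1) * e


print(adapt_inputs(1.0, 2.0, 1, add_scaled))                # 3.0
print(adapt_inputs((1.0, 1.0), (2.0, 3.0), 2, add_scaled))  # (3.0, 7.0)
print(adapt_inputs([1.0], [2.0], 1, add_scaled))            # [3.0]
```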
Examples
>>> import torch
>>> model = ConcatAdaptLayer(indim=128, enrolldim=64, ninputs=2)
>>> main_input = (torch.randn(10, 128, 50), torch.randn(10, 128, 50))
>>> enroll_input = (torch.randn(10, 64), torch.randn(10, 64))
>>> output = model(main_input, enroll_input)
>>> output[0].shape
torch.Size([10, 128, 50])
>>> output[1].shape
torch.Size([10, 128, 50])