espnet2.asr_transducer.joint_network.JointNetwork
class espnet2.asr_transducer.joint_network.JointNetwork(output_size: int, encoder_size: int, decoder_size: int, joint_space_size: int = 256, joint_activation_type: str = 'tanh', lin_dec_bias: bool = True, **activation_parameters)
Bases: Module
Transducer joint network implementation.
This module implements a joint network for transducer models in automatic speech recognition (ASR). The JointNetwork class projects encoder and decoder outputs into a shared joint space, sums them, applies a configurable activation function, and projects the result to the output dimension.
lin_enc
Linear layer for encoder output.
- Type: torch.nn.Linear
lin_dec
Linear layer for decoder output.
- Type: torch.nn.Linear
lin_out
Linear layer for producing final output.
- Type: torch.nn.Linear
joint_activation
Activation function for the joint network.
- Type: callable
Parameters:
- output_size (int) – Output size.
- encoder_size (int) – Encoder output size.
- decoder_size (int) – Decoder output size.
- joint_space_size (int, optional) – Joint space size (default is 256).
- joint_activation_type (str, optional) – Type of activation for joint network (default is 'tanh').
- lin_dec_bias (bool, optional) – Whether to include bias in the decoder linear layer (default is True).
- **activation_parameters – Additional parameters for the activation function.
####### Examples
>>> joint_network = JointNetwork(
... output_size=10,
... encoder_size=20,
... decoder_size=30,
... joint_space_size=256,
... joint_activation_type='relu'
... )
>>> enc_out = torch.randn(5, 10, 1, 20)  # (B, T, 1, D_enc)
>>> dec_out = torch.randn(5, 10, 1, 30)  # (B, T, 1, D_dec)
>>> output = joint_network(enc_out, dec_out)
>>> print(output.shape)  # torch.Size([5, 10, 1, 10])
- Raises: ValueError – If the shapes of enc_out and dec_out do not match expected dimensions.
NOTE
The input tensors enc_out and dec_out can have different shapes depending on the specific use case. The joint network computes their combined output through learned linear transformations followed by a specified activation.
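The computation described above can be illustrated with a minimal, self-contained sketch. This is not the ESPnet implementation itself; it assumes the standard transducer joint formulation (project both inputs into the joint space, add with broadcasting, apply the activation, project to the output size), with illustrative sizes:

```python
import torch

# Illustrative sizes (assumptions, not ESPnet defaults except the structure).
B, T, U = 2, 4, 3
encoder_size, decoder_size, joint_space_size, output_size = 8, 6, 16, 5

# The three linear layers mirror lin_enc, lin_dec, and lin_out.
lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
lin_dec = torch.nn.Linear(decoder_size, joint_space_size)
lin_out = torch.nn.Linear(joint_space_size, output_size)

enc_out = torch.randn(B, T, 1, encoder_size)  # (B, T, 1, D_enc)
dec_out = torch.randn(B, 1, U, decoder_size)  # (B, 1, U, D_dec)

# Broadcasting turns (B, T, 1, J) + (B, 1, U, J) into (B, T, U, J).
joint = torch.tanh(lin_enc(enc_out) + lin_dec(dec_out))
logits = lin_out(joint)
print(logits.shape)  # torch.Size([2, 4, 3, 5]) -> (B, T, U, D_out)
```

The elementwise sum is where the time axis (T) and label axis (U) meet, which is why the singleton dimensions in the two inputs matter.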
Construct a JointNetwork object.
forward(enc_out: Tensor, dec_out: Tensor, no_projection: bool = False) → Tensor
Joint computation of encoder and decoder hidden state sequences.
This method performs a joint computation by combining the outputs from the encoder and decoder. It either applies a projection to the encoder and decoder outputs or directly computes the joint output if specified.
- Parameters:
- enc_out (torch.Tensor) – Expanded encoder output state sequences. Shape can be (B, T, s_range, D_enc) or (B, T, 1, D_enc).
- dec_out (torch.Tensor) – Expanded decoder output state sequences. Shape can be (B, T, s_range, D_dec) or (B, 1, U, D_dec).
- no_projection (bool , optional) – If True, skips the projection step. Defaults to False.
- Returns: Joint output state sequences. Shape will be (B, T, U, D_out) or (B, T, s_range, D_out).
- Return type: torch.Tensor
####### Examples
>>> joint_network = JointNetwork(output_size=10, encoder_size=20,
... decoder_size=30)
>>> enc_out = torch.randn(5, 10, 1, 20) # Example encoder output
>>> dec_out = torch.randn(5, 10, 2, 30) # Example decoder output
>>> joint_output = joint_network(enc_out, dec_out)
>>> print(joint_output.shape)
torch.Size([5, 10, 2, 10])
NOTE
Ensure that the shapes of enc_out and dec_out are broadcastable for elementwise addition when no_projection is set to False.
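The no_projection path can be sketched in isolation. This is a hedged, standalone illustration (not the ESPnet code): when the projection step is skipped, the inputs are assumed to already live in the joint space, so they are summed directly before the activation and output projection:

```python
import torch

# Illustrative sizes: both inputs already have the joint-space width J.
J, output_size = 16, 5
lin_out = torch.nn.Linear(J, output_size)

enc = torch.randn(2, 4, 1, J)  # (B, T, 1, J)
dec = torch.randn(2, 1, 3, J)  # (B, 1, U, J)

# With no_projection, the sum happens without lin_enc/lin_dec;
# broadcasting still produces a (B, T, U, J) joint tensor.
logits = lin_out(torch.tanh(enc + dec))
print(logits.shape)  # torch.Size([2, 4, 3, 5])
```

If the two inputs are not broadcastable (e.g. mismatched batch sizes or last dimensions), the addition raises a runtime error, which is the compatibility requirement the NOTE above refers to.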