espnet2.speechlm.module.valle.ResidualAttentionBlockAdaLM
class espnet2.speechlm.module.valle.ResidualAttentionBlockAdaLM(n_state: int, n_head: int, cross_attention: bool = False)
Bases: ResidualAttentionBlock
ResidualAttentionBlockAdaLM implements a residual attention block with adaptive layer normalization (AdaLN) for language modeling. It extends the ResidualAttentionBlock class, replacing standard layer normalization with AdaLN conditioned on a level embedding.
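The AdaLN implementation itself is not reproduced on this page. The following is a minimal, illustrative sketch of the general idea, assuming a (scale, shift) pair is looked up per level id; the class name `AdaLNSketch` and the `n_level` argument are hypothetical and not part of the ESPnet API:

```python
import torch
import torch.nn as nn


class AdaLNSketch(nn.Module):
    """Illustrative adaptive LayerNorm: affine parameters chosen by a level id."""

    def __init__(self, n_state: int, n_level: int = 8):
        super().__init__()
        # Plain LayerNorm without its own affine parameters.
        self.norm = nn.LayerNorm(n_state, elementwise_affine=False)
        # One learnable (scale, shift) pair per level id.
        self.scale = nn.Embedding(n_level, n_state)
        self.shift = nn.Embedding(n_level, n_state)
        nn.init.ones_(self.scale.weight)
        nn.init.zeros_(self.shift.weight)

    def forward(self, x: torch.Tensor, level: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq, n_state); level: (batch,) integer level ids
        s = self.scale(level).unsqueeze(1)  # (batch, 1, n_state)
        b = self.shift(level).unsqueeze(1)
        return self.norm(x) * s + b


x = torch.randn(2, 5, 512)
level = torch.randint(0, 8, (2,))
print(AdaLNSketch(512)(x, level).shape)  # torch.Size([2, 5, 512])
```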
n_state
The dimensionality of the input and output states.
- Type: int
n_head
The number of attention heads.
- Type: int
cross_attention
A flag indicating whether to use cross-attention.
- Type: bool
Parameters:
- n_state (int) – The dimensionality of the input and output states.
- n_head (int) – The number of attention heads.
- cross_attention (bool, optional) – A flag to enable cross-attention. Defaults to False.
forward(x: Tensor, level: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None) -> Tensor
Computes the forward pass of the residual attention block.
- Raises: ValueError – If the input tensor dimensions do not match the expected dimensions.
####### Examples
>>> block = ResidualAttentionBlockAdaLM(n_state=512, n_head=8)
>>> x = torch.randn(10, 20, 512) # (batch_size, seq_len, n_state)
>>> level = torch.randint(0, 5, (10,)) # (batch_size,)
>>> output = block(x, level)
>>> print(output.shape)
torch.Size([10, 20, 512])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
forward(x: Tensor, level: Tensor, xa: Optional[Tensor] = None, mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None) -> Tensor
Executes the forward pass of the ResidualAttentionBlockAdaLM module.
This method applies self-attention and feed-forward operations to the input tensor, using adaptive layer normalization and residual connections. If cross-attention was enabled during initialization, it also attends to the cross-attention input xa.
- Parameters:
- x (Tensor) – The input tensor of shape (batch_size, seq_length, n_state).
- level (Tensor) – The level embedding tensor of shape (batch_size,).
- xa (Optional[Tensor], optional) – The cross-attention input tensor of shape (batch_size, seq_length, n_state). Defaults to None.
- mask (Optional[Tensor], optional) – The attention mask tensor used to prevent attending to certain positions. Defaults to None.
- kv_cache (Optional[dict], optional) – A cache dictionary for key-value pairs in attention. Defaults to None.
- Returns: The output tensor after applying attention and feed-forward operations, with the same shape as the input tensor x.
- Return type: Tensor
####### Examples
>>> block = ResidualAttentionBlockAdaLM(n_state=512, n_head=8)
>>> x = torch.rand(10, 20, 512) # batch_size=10, seq_length=20
>>> level = torch.randint(0, 5, (10,)) # 5 levels
>>> output = block(x, level)
>>> output.shape
torch.Size([10, 20, 512])
NOTE
This method is part of the ResidualAttentionBlockAdaLM class, which extends ResidualAttentionBlock with adaptive layer normalization conditioned on the level embedding.
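For the cross-attention variant, a usage sketch is shown below. The tensor shapes follow the parameter descriptions above; the additive -inf causal mask convention is an assumption and may differ from the actual implementation:

```python
import torch
from espnet2.speechlm.module.valle import ResidualAttentionBlockAdaLM

block = ResidualAttentionBlockAdaLM(n_state=512, n_head=8, cross_attention=True)

x = torch.randn(4, 16, 512)        # target states: (batch_size, seq_length, n_state)
xa = torch.randn(4, 32, 512)       # cross-attention (encoder) states
level = torch.randint(0, 5, (4,))  # one level id per batch element
# Additive causal mask (assumed convention): -inf strictly above the diagonal.
mask = torch.full((16, 16), float("-inf")).triu(1)

out = block(x, level, xa=xa, mask=mask)
print(out.shape)  # torch.Size([4, 16, 512])
```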