espnet2.asr_transducer.decoder.rwkv_decoder.RWKVDecoder
class espnet2.asr_transducer.decoder.rwkv_decoder.RWKVDecoder(vocab_size: int, block_size: int = 512, context_size: int = 1024, linear_size: int | None = None, attention_size: int | None = None, normalization_type: str = 'layer_norm', normalization_args: Dict = {}, num_blocks: int = 4, rescale_every: int = 0, embed_dropout_rate: float = 0.0, att_dropout_rate: float = 0.0, ffn_dropout_rate: float = 0.0, embed_pad: int = 0)
Bases: AbsDecoder
RWKV decoder module for Transducer models.
This class implements the RWKV decoder based on the architecture described in the paper: https://arxiv.org/pdf/2305.13048.pdf. It is designed to work with Transducer models for automatic speech recognition tasks.
block_size
The size of the input/output blocks.
- Type: int
attention_size
The hidden size for self-attention layers.
- Type: int
output_size
The size of the output layer.
- Type: int
vocab_size
The number of unique tokens in the vocabulary.
- Type: int
context_size
The size of the context for WKV computation.
- Type: int
rescale_every
Frequency of input rescaling in inference mode.
- Type: int
rescaled_layers
Flag indicating if layers are rescaled.
- Type: bool
pad_idx
The ID for the padding symbol in embeddings.
- Type: int
num_blocks
The number of RWKV blocks in the decoder.
- Type: int
score_cache
Cache for storing scores during decoding.
- Type: dict
device
The device on which the model is located.
- Type: torch.device
Parameters:
- vocab_size (int) – Vocabulary size.
- block_size (int, optional) – Input/Output size. Default is 512.
- context_size (int, optional) – Context size for WKV computation. Default is 1024.
- linear_size (int, optional) – FeedForward hidden size. Default is None.
- attention_size (int, optional) – SelfAttention hidden size. Default is None.
- normalization_type (str, optional) – Normalization layer type. Default is "layer_norm".
- normalization_args (Dict, optional) – Normalization layer arguments. Default is {}.
- num_blocks (int, optional) – Number of RWKV blocks. Default is 4.
- rescale_every (int, optional) – Rescale input every N blocks (inference only). Default is 0.
- embed_dropout_rate (float, optional) – Dropout rate for the embedding layer. Default is 0.0.
- att_dropout_rate (float, optional) – Dropout rate for the attention module. Default is 0.0.
- ffn_dropout_rate (float, optional) – Dropout rate for the feed-forward module. Default is 0.0.
- embed_pad (int, optional) – Embedding padding symbol ID. Default is 0.
Examples
>>> import torch
>>> # Initialize the RWKVDecoder
>>> decoder = RWKVDecoder(vocab_size=1000, block_size=512)
>>> # Forward pass through the decoder
>>> labels = torch.randint(0, 1000, (32, 10))  # (B, L)
>>> output = decoder(labels)
>>> # Inference with hidden states
>>> states = decoder.init_state(batch_size=32)
>>> output, new_states = decoder.inference(labels, states)
- Raises: AssertionError – If the length of the input labels exceeds the context size.
NOTE: This implementation uses PyTorch; a CUDA-capable environment is required if GPU acceleration is desired.
Construct an RWKVDecoder object.
batch_score(hyps: List[Hypothesis]) -> Tuple[torch.Tensor, List[torch.Tensor]]
One-step forward hypotheses.
This method processes a batch of hypotheses and computes the decoder’s output for each hypothesis. It takes the last label from each hypothesis and uses the decoder’s inference method to generate the output and update the hidden states.
Parameters: hyps – A list of Hypothesis objects representing the current hypotheses. Each Hypothesis contains a label sequence and a decoder state.
Returns:
    out: The decoder output sequence. Shape is (B, D_dec), where B is the batch size and D_dec is the decoder output dimension.
    states: The updated decoder hidden states. Shape is [5 x (B, 1, D_att/D_dec, N)], where D_att is the attention dimension and N is the number of blocks.
Examples
>>> decoder = RWKVDecoder(vocab_size=1000)
>>> hypotheses = [Hypothesis(yseq=[1, 2, 3], dec_state=initial_state)]
>>> output, states = decoder.batch_score(hyps=hypotheses)
>>> print(output.shape) # Should output (1, D_dec)
NOTE: Ensure that the create_batch_states method is compatible with the structure of the decoder hidden states expected by the inference process.
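For illustration, the one-step forward amounts to gathering the last label of each hypothesis, batching the per-hypothesis states, and running one inference step. A minimal sketch, assuming each Hypothesis exposes yseq and dec_state as in the example above (not the library's exact implementation):
>>> # Gather the last label of each hypothesis into a (B, 1) tensor
>>> labels = torch.LongTensor([[h.yseq[-1]] for h in hypotheses])
>>> # Batch the per-hypothesis states, then run one decoding step
>>> states = decoder.create_batch_states([h.dec_state for h in hypotheses])
>>> out, states = decoder.inference(labels, states)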
create_batch_states(new_states: List[List[torch.Tensor]]) -> List[torch.Tensor]
Create batch of decoder hidden states given a list of new states.
This method takes a list of new states for each hypothesis in the batch and combines them into a single batch of hidden states. The resulting hidden states can be used for further processing in the decoder.
- Parameters: new_states – A list of new decoder hidden states, where each entry corresponds to a hypothesis and is structured as [B x [5 x (1, 1, D_att/D_dec, N)]].
- Returns: A list of decoder hidden states, structured as [5 x (B, 1, D_att/D_dec, N)], where B is the batch size.
Examples
>>> new_states = [
... [torch.randn(1, 1, 128, 4) for _ in range(5)], # Hypothesis 1
... [torch.randn(1, 1, 128, 4) for _ in range(5)], # Hypothesis 2
... ]
>>> batch_states = decoder.create_batch_states(new_states)
>>> len(batch_states)
5
>>> batch_states[0].shape
torch.Size([2, 1, 128, 4]) # 2 hypotheses in the batch
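Internally this is equivalent to concatenating the per-hypothesis tensors along the batch dimension. A minimal sketch of that operation (an assumption about the behaviour, shown for illustration only):
>>> # Stack each of the 5 state tensors across hypotheses (dim 0 = batch)
>>> batch_states = [
...     torch.cat([hyp_states[i] for hyp_states in new_states], dim=0)
...     for i in range(5)
... ]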
forward(labels: torch.Tensor) -> torch.Tensor
Encode source label sequences.
- Parameters: labels – Label ID sequences. Shape is (B, L), where B is the batch size and L is the label sequence length.
- Returns: Decoder output sequences. Shape is (B, L, D_dec), where D_dec is the decoder output dimension (block_size).
Examples
>>> decoder = RWKVDecoder(vocab_size=1000, block_size=512)
>>> labels = torch.randint(0, 1000, (32, 20))  # (B, L)
>>> output = decoder.forward(labels)
>>> print(output.shape)
torch.Size([32, 20, 512])
NOTE: Ensure that the input length does not exceed the context size.
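The context-size restriction can be checked before calling forward. A minimal sketch (the truncation policy shown is an illustrative choice, not part of the API):
>>> # Guard against label sequences longer than the WKV context window;
>>> # the decoder itself raises an AssertionError in that case.
>>> context_size = 1024
>>> if labels.size(1) > context_size:
...     labels = labels[:, -context_size:]  # keep the most recent labels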
inference(labels: torch.Tensor, states: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]
Encode source label sequences with hidden states.
Parameters:
- labels (torch.Tensor) – Label ID sequences. Shape is (B, L).
- states (List[torch.Tensor]) – Decoder hidden states. Shape is [5 x (B, 1, D_att/D_dec, N)].
Returns:
    out: Decoder output sequences. Shape is (B, L, D_dec).
    states: The updated decoder hidden states. Shape is [5 x (B, 1, D_att/D_dec, N)].
Examples
>>> # Creating an instance of the RWKVDecoder
>>> decoder = RWKVDecoder(
...     vocab_size=1000, block_size=512, context_size=1024, num_blocks=4
... )
>>> # Forward pass with labels
>>> labels = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)
>>> output = decoder(labels)
>>> # Performing inference
>>> states = decoder.init_state(batch_size=2)
>>> output, new_states = decoder.inference(labels, states)
- Raises: AssertionError – If the length of the input labels exceeds the context size.
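For step-by-step decoding, inference is typically called one label at a time while threading the hidden states through each call. A minimal sketch (the loop and placeholder labels are illustrative; in a full Transducer, the output would feed a joint network together with encoder frames):
>>> # Feed one label per step and carry the RWKV states forward
>>> label = torch.full((2, 1), 1, dtype=torch.long)  # hypothetical start label
>>> states = decoder.init_state(batch_size=2)
>>> for step in range(5):
...     out, states = decoder.inference(label, states)  # out: (2, 1, D_dec)
...     label = torch.randint(1, 1000, (2, 1))  # placeholder for search-chosen labels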
init_state(batch_size: int = 1) -> List[torch.Tensor]
Initialize RWKVDecoder states.
- Parameters: batch_size – Batch size. Default is 1.
- Returns: Decoder hidden states, structured as [5 x (B, 1, D_att/D_dec, N)], where B is the batch size, D_att is the attention dimension, D_dec is the decoder dimension, and N is the number of blocks.
Examples
>>> decoder = RWKVDecoder(
...     vocab_size=1000,
...     block_size=512,
...     context_size=1024,
...     linear_size=2048,
...     attention_size=512,
...     num_blocks=4,
... )
>>> states = decoder.init_state(batch_size=2)
NOTE: The init_state method initializes the decoder's hidden states for a specified batch size. The hidden states consist of a list of tensors representing the state of each RWKV block.
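The returned list can be inspected directly. A short sketch, with shapes assuming the configuration above (attention_size=512, num_blocks=4; the per-tensor dimension is D_att or D_dec depending on the state's role):
>>> for i, s in enumerate(states):
...     print(i, s.shape)  # e.g. torch.Size([2, 1, 512, 4])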
score(label_sequence: List[int], states: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]
One-step forward hypothesis. Computed outputs are stored in score_cache (see the class attributes) so that repeated calls with the same label sequence avoid recomputation.
Parameters:
- label_sequence (List[int]) – Current label sequence as a list of token IDs.
- states (List[torch.Tensor]) – Decoder hidden states. Shape is [5 x (1, 1, D_att/D_dec, N)].
Returns:
    out: Decoder output for the last label of the sequence.
    states: The updated decoder hidden states.
Examples
>>> decoder = RWKVDecoder(vocab_size=10000, block_size=512)
>>> states = decoder.init_state(batch_size=1)
>>> out, new_states = decoder.score([1, 2, 3], states)
- Raises: AssertionError – If the input length exceeds the context size.
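During search, a hypothesis is typically extended one token at a time, re-scoring the growing prefix and reusing the returned states. A usage sketch (the token values are illustrative):
>>> states = decoder.init_state(batch_size=1)
>>> label_sequence = [1]  # hypothetical running prefix
>>> for next_token in [5, 7, 2]:
...     out, states = decoder.score(label_sequence, states)
...     label_sequence.append(next_token)  # token chosen by the search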
select_state(states: List[torch.Tensor], idx: int) -> List[torch.Tensor]
Select ID state from batch of decoder hidden states.
This method extracts the hidden states for a specific index from a batch of decoder hidden states. The hidden states are represented as a list of tensors, where each tensor corresponds to a different aspect of the state.
- Parameters:
- states – Decoder hidden states. A list of tensors with shape [5 x (B, 1, D_att/D_dec, N)], where B is the batch size, D_att is the attention dimension, D_dec is the decoder dimension, and N is the number of blocks.
- idx – The index of the state to select from the batch.
- Returns: A list of tensors representing the decoder hidden states for the specified index. The shape of each tensor is [1, 1, D_att/D_dec, N].
Examples
>>> states = [
...     torch.randn(2, 1, 128, 4)  # one tensor per state aspect, batch of 2
...     for _ in range(5)
... ]
>>> idx = 0
>>> selected_state = decoder.select_state(states, idx)
>>> len(selected_state)
5
>>> selected_state[0].shape
torch.Size([1, 1, 128, 4])
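select_state and create_batch_states are natural counterparts during search: split a batch into per-hypothesis states, prune, then regroup the survivors. A sketch (the pruning step is illustrative):
>>> per_hyp = [decoder.select_state(states, i) for i in range(2)]
>>> survivors = [per_hyp[0]]  # e.g. hypotheses kept after pruning
>>> states = decoder.create_batch_states(survivors)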
set_device(device: torch.device) -> None
Set GPU device to use.
This method allows you to specify the device on which the decoder will operate. It is particularly useful for transferring the model to a different GPU or CPU.
- Parameters: device – The device to set (e.g., torch.device('cuda:0') or torch.device('cpu')).
Examples
>>> decoder = RWKVDecoder(vocab_size=1000)
>>> decoder.set_device(torch.device('cuda:0'))
NOTE: Make sure that the device is available and compatible with the current model parameters.
- Raises: ValueError – If the specified device is not valid.
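A defensive usage sketch that falls back to CPU when no GPU is present:
>>> device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
>>> decoder.set_device(device)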