espnet2.asr_transducer.decoder.modules.rwkv.feed_forward.FeedForward
class espnet2.asr_transducer.decoder.modules.rwkv.feed_forward.FeedForward(size: int, hidden_size: int, block_id: int, num_blocks: int)
Bases: Module
Feed-forward (channel mixing) module for RWKV block.
Based on and modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
time_shift
Zero padding layer for time shifting.
- Type: torch.nn.ZeroPad2d
time_mix_key
Parameter for time mixing key.
- Type: torch.nn.Parameter
time_mix_receptance
Parameter for time mixing receptance.
- Type: torch.nn.Parameter
proj_key
Linear transformation for the key projection.
- Type: torch.nn.Linear
proj_value
Linear transformation for the value projection.
- Type: torch.nn.Linear
proj_receptance
Linear transformation for the receptance projection.
- Type: torch.nn.Linear
block_id
Index of the current block.
- Type: int
Parameters:
- size (int) – Input/Output size.
- hidden_size (int) – Hidden size.
- block_id (int) – Block index.
- num_blocks (int) – Number of blocks in the architecture.
reset_parameters(size: int, block_id: int, num_blocks: int) -> None: Reset module parameters based on block size and index.
forward(x: torch.Tensor, state: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]: Compute channel mixing for the input sequences.
#### Examples

>>> import torch
>>> from espnet2.asr_transducer.decoder.modules.rwkv.feed_forward import FeedForward
>>> ff = FeedForward(size=256, hidden_size=128, block_id=0, num_blocks=4)
>>> input_tensor = torch.randn(32, 10, 256)  # (Batch, Sequence, Size)
>>> output, state = ff(input_tensor)
- Raises: ValueError – If the input tensor shape is incorrect.
#### NOTE

This module is part of the RWKV architecture, which is designed for efficient sequence modeling tasks.
Construct a FeedForward object.
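For orientation, the sketch below outlines the channel-mixing computation this module implements, following the RWKV reference code linked above. It is a simplified illustration that ignores the streaming state; channel_mixing_sketch is a hypothetical helper, not part of the ESPnet API, and the actual implementation may differ in detail.

```python
import torch


def channel_mixing_sketch(ff, x):
    """Illustrative RWKV channel mixing using the attributes documented above."""
    # Shift the sequence one step back in time (the first frame sees zeros).
    shifted = ff.time_shift(x)

    # Interpolate between the current and shifted inputs, per channel.
    key_input = x * ff.time_mix_key + shifted * (1 - ff.time_mix_key)
    receptance_input = x * ff.time_mix_receptance + shifted * (1 - ff.time_mix_receptance)

    # Squared-ReLU key activation, followed by the value projection.
    key = torch.square(torch.relu(ff.proj_key(key_input)))
    value = ff.proj_value(key)

    # The sigmoid of the receptance gates the value output.
    return torch.sigmoid(ff.proj_receptance(receptance_input)) * value
```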
forward(x: torch.Tensor, state: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]

Compute the channel mixing operation for the input sequences.

- Parameters:
- x – FeedForward input sequences, shape (B, U, size).
- state – Decoder hidden state, shape [5 x (B, 1, size, N)].
- Returns:
- x: FeedForward output sequences with shape (B, U, size).
- state: Updated decoder hidden state, shape [5 x (B, 1, size, N)].
- Return type: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
#### Examples

>>> ff = FeedForward(size=256, hidden_size=512, block_id=0, num_blocks=2)
>>> input_tensor = torch.randn(32, 10, 256) # (B, U, size)
>>> output, state = ff.forward(input_tensor)
#### NOTE

The state parameter is optional. If provided, it should contain the decoder hidden state for the current block.
- Raises: ValueError – If the input tensor dimensions do not match the expected size.
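As a further, hedged illustration of the optional state argument: the snippet below pre-allocates a hidden state with the documented [5 x (B, 1, size, N)] layout and feeds one frame at a time, as a streaming decoder might. The shapes are taken from this page; how the surrounding RWKV decoder actually creates and manages this state is not shown here.

```python
import torch

from espnet2.asr_transducer.decoder.modules.rwkv.feed_forward import FeedForward

size, num_blocks, batch = 256, 2, 8
ff = FeedForward(size=size, hidden_size=512, block_id=0, num_blocks=num_blocks)

# Hidden state pre-allocated with the layout documented above: [5 x (B, 1, size, N)].
state = [torch.zeros(batch, 1, size, num_blocks) for _ in range(5)]

frame = torch.randn(batch, 1, size)      # one frame at a time, (B, U=1, size)
output, state = ff(frame, state=state)   # the updated state is returned with the output
print(output.shape)                      # torch.Size([8, 1, 256])
```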
reset_parameters(size: int, block_id: int, num_blocks: int) -> None

Reset module parameters.
This method initializes the parameters of the FeedForward module based on the provided size, block_id, and the total number of blocks in the architecture. The parameters are set using a power function of a time weight tensor, which helps control the influence of different time steps during training.
- Parameters:
- size – Block size, which determines the dimensions of the input and output.
- block_id – The index of the current block in the architecture.
- num_blocks – The total number of blocks in the architecture.
#### NOTE

The time mixing parameters are initialized using a ratio that scales with the block index, which helps manage the temporal dynamics of the model.
#### Examples

>>> ff = FeedForward(size=128, hidden_size=256, block_id=0, num_blocks=4)
>>> ff.reset_parameters(size=128, block_id=0, num_blocks=4)
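For reference, here is a minimal sketch of the kind of initialization described above, following the RWKV reference code: a per-channel ramp raised to a power that decays with the block index. The helper name and the exact formula are illustrative assumptions; the ESPnet implementation may differ in detail.

```python
import torch


def reset_parameters_sketch(module, size, block_id, num_blocks):
    """Illustrative initialization of the time mixing parameters."""
    # Ratio close to 1 for the first block, approaching 0 for deeper blocks.
    ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

    # Per-channel time weight: channel i gets i / size, shaped (1, 1, size).
    time_weight = torch.arange(size, dtype=torch.float32).view(1, 1, size) / size

    # Time mixing parameters are a power function of the time weight tensor.
    module.time_mix_key = torch.nn.Parameter(torch.pow(time_weight, ratio_1_to_almost0))
    module.time_mix_receptance = torch.nn.Parameter(torch.pow(time_weight, ratio_1_to_almost0))
```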