espnet2.legacy.nets.pytorch_backend.rnn.attentions.AttMultiHeadLoc
class espnet2.legacy.nets.pytorch_backend.rnn.attentions.AttMultiHeadLoc(eprojs, dunits, aheads, att_dim_k, att_dim_v, aconv_chans, aconv_filts, han_mode=False)
Bases: Module
Multi-head location-based attention.
Reference: Attention Is All You Need (https://arxiv.org/abs/1706.03762)
This module computes multi-head attention, using location-aware attention for each head.
- Parameters:
- eprojs (int) – # projection-units of encoder
- dunits (int) – # units of decoder
- aheads (int) – # heads of multi-head attention
- att_dim_k (int) – dimension of key k in multi-head attention
- att_dim_v (int) – dimension of value v in multi-head attention
- aconv_chans (int) – # channels of attention convolution
- aconv_filts (int) – filter size of attention convolution
- han_mode (bool) – flag to switch on hierarchical attention mode; when True, pre_compute_k and pre_compute_v are not stored
Initialize AttMultiHeadLoc.
forward(enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0, **kwargs)
Calculate AttMultiHeadLoc forward propagation.
- Parameters:
- enc_hs_pad (torch.Tensor) – padded encoder hidden states (B x T_max x D_enc)
- enc_hs_len (list) – padded encoder hidden state lengths (B)
- dec_z (torch.Tensor) – decoder hidden state (B x D_dec)
- att_prev (torch.Tensor) – list of previous attention weights, (B x T_max) * aheads
- scaling (float) – scaling parameter applied before the softmax
- Returns: attention weighted encoder state (B x D_enc)
- Return type: torch.Tensor
- Returns: list of attention weights, (B x T_max) * aheads
- Return type: list
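To illustrate what each head computes, the following is a minimal pure-Python sketch of one attention step: a dot-product score between the decoder state and each encoder frame plus a location term, a scaled softmax over time, and an attention-weighted sum of encoder frames. The function name and the simplified scoring are assumptions for illustration only; the actual ESPnet class additionally applies learned key/value/query projections per head and derives the location term from a convolution over the previous attention weights.

```python
import math

def scaled_location_attention(enc_h, dec_z, loc_score, scaling=2.0):
    """Hypothetical single-head sketch (not the ESPnet API).

    enc_h: list of encoder frames (each a list of floats, length D_enc)
    dec_z: decoder hidden state (list of floats, length D_enc here for simplicity)
    loc_score: per-frame location bias, e.g. derived from previous weights
    """
    # Score each frame: dot(dec_z, frame) + location bias for that frame.
    scores = [sum(d * h for d, h in zip(dec_z, frame)) + loc
              for frame, loc in zip(enc_h, loc_score)]
    # Softmax with the scaling parameter applied before exponentiation.
    exps = [math.exp(scaling * s) for s in scores]
    denom = sum(exps)
    weights = [e / denom for e in exps]
    # Context vector: attention-weighted sum of encoder frames (length D_enc).
    context = [sum(w * frame[i] for w, frame in zip(weights, enc_h))
               for i in range(len(enc_h[0]))]
    return context, weights
```

In the real module this computation runs once per head (aheads times), the per-head contexts are concatenated, and the list of per-head weight vectors is returned as `att_prev` for the next decoding step.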
reset()
Reset states.
