espnet2.mt.espnet_model.ESPnetMTModel
class espnet2.mt.espnet_model.ESPnetMTModel(vocab_size: int, token_list: Tuple[str, ...] | List[str], frontend: AbsFrontend | None, preencoder: AbsPreEncoder | None, encoder: AbsEncoder, postencoder: AbsPostEncoder | None, decoder: AbsDecoder, src_vocab_size: int = 0, src_token_list: Tuple[str, ...] | List[str] = [], ignore_id: int = -1, lsm_weight: float = 0.0, length_normalized_loss: bool = False, report_bleu: bool = True, sym_space: str = '<space>', sym_blank: str = '<blank>', patch_size: int = 1, extract_feats_in_collect_stats: bool = True, share_decoder_input_output_embed: bool = False, share_encoder_decoder_input_embed: bool = False)
Bases: AbsESPnetModel
Encoder-Decoder model for machine translation.
This class implements an encoder-decoder architecture specifically for machine translation tasks. It allows for various configurations of frontends, encoders, decoders, and post-encoders. The model can also compute loss using label smoothing and track BLEU scores during evaluation.
sos
Start of sequence token ID.
- Type: int
eos
End of sequence token ID.
- Type: int
src_sos
Source start of sequence token ID.
- Type: int
src_eos
Source end of sequence token ID.
- Type: int
vocab_size
Size of the target vocabulary.
- Type: int
src_vocab_size
Size of the source vocabulary.
- Type: int
ignore_id
Token ID to ignore during loss computation.
- Type: int
patch_size
Size of the patch for feature extraction.
- Type: int
token_list
List of tokens for BLEU score calculation.
- Type: List[str]
frontend
Frontend module for feature extraction.
- Type: AbsFrontend
preencoder
Pre-encoder module.
- Type: AbsPreEncoder
postencoder
Post-encoder module.
- Type: AbsPostEncoder
encoder
Encoder module.
- Type: AbsEncoder
decoder
Decoder module.
- Type: AbsDecoder
criterion_mt
Loss function for machine translation.
- Type: LabelSmoothingLoss
mt_error_calculator
Calculator for BLEU scores.
- Type: MTErrorCalculator
extract_feats_in_collect_stats
Whether to extract real features in collect_feats (otherwise dummy features are returned).
- Type: bool
Parameters:
- vocab_size (int) – Size of the target vocabulary.
- token_list (Union[Tuple[str, ...], List[str]]) – List of target tokens.
- frontend (Optional [AbsFrontend ]) – Frontend module for feature extraction.
- preencoder (Optional [AbsPreEncoder ]) – Pre-encoder module.
- encoder (AbsEncoder) – Encoder module.
- postencoder (Optional [AbsPostEncoder ]) – Post-encoder module.
- decoder (AbsDecoder) – Decoder module.
- src_vocab_size (int , optional) – Size of the source vocabulary. Default is 0.
- src_token_list (Union[Tuple[str, ...], List[str]] , optional) – List of source tokens. Default is an empty list.
- ignore_id (int , optional) – Token ID to ignore during loss computation. Default is -1.
- lsm_weight (float , optional) – Weight for label smoothing. Default is 0.0.
- length_normalized_loss (bool , optional) – Whether to use length-normalized loss. Default is False.
- report_bleu (bool , optional) – Whether to report BLEU score. Default is True.
- sym_space (str , optional) – Symbol for space. Default is "<space>".
- sym_blank (str , optional) – Symbol for blank. Default is "<blank>".
- patch_size (int , optional) – Size of the patch for feature extraction. Default is 1.
- extract_feats_in_collect_stats (bool , optional) – Flag to extract features in collect stats. Default is True.
- share_decoder_input_output_embed (bool , optional) – Whether to share decoder input and output embeddings. Default is False.
- share_encoder_decoder_input_embed (bool , optional) – Whether to share encoder and decoder input embeddings. Default is False.
######### Examples
Create a machine translation model (SomeEncoder and SomeDecoder stand in for concrete AbsEncoder/AbsDecoder implementations):

model = ESPnetMTModel(
    vocab_size=5000,
    token_list=["<blank>", "<space>", "hello", "world"],
    frontend=None,
    preencoder=None,
    encoder=SomeEncoder(),
    postencoder=None,
    decoder=SomeDecoder(),
    src_vocab_size=3000,
    src_token_list=["<blank>", "<space>", "hola", "mundo"],
    ignore_id=-1,
    lsm_weight=0.1,
    length_normalized_loss=True,
    report_bleu=True,
    sym_space="<space>",
    sym_blank="<blank>",
    patch_size=1,
    extract_feats_in_collect_stats=True,
    share_decoder_input_output_embed=False,
    share_encoder_decoder_input_embed=False,
)
NOTE
The forward method computes the output and loss for a given input batch. The encode method extracts features from the source text and passes them through the encoder.
- Raises: AssertionError – If the input tensor dimensions do not match the expected dimensions.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
collect_feats(text: Tensor, text_lengths: Tensor, src_text: Tensor, src_text_lengths: Tensor, **kwargs) → Dict[str, Tensor]
Collect features from the source text and return them along with their lengths.
This method extracts features from the source text tensor and returns them in a dictionary. If the extract_feats_in_collect_stats attribute is set to True, actual features are extracted; otherwise, dummy features are generated.
- Parameters:
- text (torch.Tensor) – The target text tensor of shape (Batch, Length).
- text_lengths (torch.Tensor) – The lengths of the target text of shape (Batch,).
- src_text (torch.Tensor) – The source text tensor of shape (Batch, Length).
- src_text_lengths (torch.Tensor) – The lengths of the source text of shape (Batch,).
- **kwargs – Additional keyword arguments, if any.
- Returns: A dictionary containing:
  - "feats" (torch.Tensor): The extracted features.
  - "feats_lengths" (torch.Tensor): The lengths of the extracted features.
- Return type: Dict[str, torch.Tensor]
- Raises: AssertionError – If src_text_lengths is not a 1-D tensor.
######### Examples
>>> model = ESPnetMTModel(vocab_size=1000, token_list=['<blank>', '<sos>', '<eos>'],
...                       frontend=None, preencoder=None, encoder=encoder,
...                       postencoder=None, decoder=decoder)
>>> text = torch.randint(0, 1000, (32, 20))
>>> text_lengths = torch.randint(1, 21, (32,))
>>> src_text = torch.randint(0, 1000, (32, 15))
>>> src_text_lengths = torch.randint(1, 16, (32,))
>>> feats = model.collect_feats(text, text_lengths, src_text,
... src_text_lengths)
>>> print(feats["feats"].shape) # Output: (Batch, NSamples, Dim)
NOTE
This method is particularly useful for feature extraction during model evaluation or inference.
encode(src_text: Tensor, src_text_lengths: Tensor) → Tuple[Tensor, Tensor]
Frontend + Encoder. Note that this method is used by mt_inference.py
- Parameters:
  - src_text – (Batch, Length, …)
  - src_text_lengths – (Batch,)
- Returns: A tuple of the encoder output of shape (Batch, Length, Dim) and the encoder output lengths of shape (Batch,).
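######### Examples

A minimal usage sketch (assuming model is the instance built in the class-level example above; the output Dim depends on the chosen encoder):

>>> src_text = torch.randint(0, 1000, (8, 15))
>>> src_text_lengths = torch.full((8,), 15, dtype=torch.long)
>>> encoder_out, encoder_out_lens = model.encode(src_text, src_text_lengths)
>>> print(encoder_out.shape)  # (Batch, Length, Dim)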
forward(text: Tensor, text_lengths: Tensor, src_text: Tensor, src_text_lengths: Tensor, **kwargs) → Tuple[Tensor, Dict[str, Tensor], Tensor]
Frontend + Encoder + Decoder + Calc loss
This method processes the input through the frontend, encoder, and decoder to calculate the loss for the given input text and source text.
- Parameters:
- text – A tensor of shape (Batch, Length) representing the target text.
- text_lengths – A tensor of shape (Batch,) containing the lengths of each sequence in the target text.
- src_text – A tensor of shape (Batch, Length) representing the source text.
- src_text_lengths – A tensor of shape (Batch,) containing the lengths of each sequence in the source text.
- kwargs – Additional keyword arguments; "utt_id" is expected to be among them.
- Returns: A tuple containing:
  - loss: A tensor representing the calculated loss.
  - stats: A dictionary with keys 'loss', 'acc', and 'bleu', containing the corresponding statistics.
  - weight: A tensor representing the batch size, used for gathering statistics.
- Return type: Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]
- Raises:AssertionError – If the dimensions of input tensors do not match.
######### Examples
>>> model = ESPnetMTModel(...)
>>> text = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> text_lengths = torch.tensor([3, 3])
>>> src_text = torch.tensor([[7, 8, 9], [10, 11, 12]])
>>> src_text_lengths = torch.tensor([3, 3])
>>> loss, stats, weight = model.forward(text, text_lengths, src_text,
... src_text_lengths)
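The returned loss can be used directly for backpropagation, while weight (the batch size) is what the trainer uses when gathering the statistics:

>>> loss.backward()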