ESPnet3 Model Configuration (Training)
This page explains how model and task in train.yaml map to model construction for the train / collect_stats stages.
Two modes: task (ESPnet2) vs model._target_ (custom)
Use ESPnet2-style models (task)
If you want to reuse an ESPnet2-derived model stack, set task and use an ESPnet2-style model: block.

    task: espnet3.systems.asr.task.ASRTask
    model:
      encoder: transformer
      decoder: transformer
      # ...ESPnet2-style config...

Tip: you can start from existing ESPnet2 configs under egs2/*/*/conf/*.yaml. See the ESPnet2 task reference for task names and links to the corresponding recipe docs.
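
To make the two modes concrete, here is a minimal sketch of how a loader could tell them apart from a parsed train.yaml. The function name detect_model_mode and the labels it returns are hypothetical, and how ESPnet3 treats a config that sets both keys is not covered on this page, so the sketch only reports that case rather than resolving it.

```python
from typing import Any, Mapping


def detect_model_mode(cfg: Mapping[str, Any]) -> str:
    """Report which model-construction mode a parsed train.yaml asks for.

    Hypothetical helper for illustration; not part of ESPnet3.
    """
    has_task = cfg.get("task") is not None
    model_cfg = cfg.get("model") or {}
    has_target = "_target_" in model_cfg

    if has_task and has_target:
        return "ambiguous"  # both set; actual precedence is not documented here
    if has_task:
        return "task"
    if has_target:
        return "target"
    return "unspecified"


espnet2_style = {
    "task": "espnet3.systems.asr.task.ASRTask",
    "model": {"encoder": "transformer", "decoder": "transformer"},
}
custom_style = {"model": {"_target_": "src.my_model.MyModel"}}

print(detect_model_mode(espnet2_style))  # task
print(detect_model_mode(custom_style))   # target
```
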
Typical ASR model: keys in ESPnet2 configs:
| Key | Purpose |
|---|---|
| encoder / encoder_conf | Encoder type and settings. |
| decoder / decoder_conf | Decoder type and settings. |
| model / model_conf | ASR model head and loss settings (CTC/attention, etc.). |
| frontend / frontend_conf | Feature extraction (e.g., STFT/FBANK). |
| specaug / specaug_conf | SpecAugment settings. |
| normalize / normalize_conf | Feature normalization (e.g., global MVN). |
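
Each row above pairs a selector key with a matching _conf key that holds its settings. As a rough illustration, the sketch below splits a model block, written here as a Python dict, into (choice, settings) pairs. The helper pair_components and the concrete values (such as output_size: 256) are assumptions made up for the example, not defaults taken from any recipe.

```python
from typing import Any, Dict, Tuple

COMPONENTS = ("frontend", "specaug", "normalize", "encoder", "decoder", "model")


def pair_components(model_cfg: Dict[str, Any]) -> Dict[str, Tuple[Any, Dict[str, Any]]]:
    """Group each `X` key with its `X_conf` settings (empty dict if absent)."""
    pairs = {}
    for name in COMPONENTS:
        if name in model_cfg:
            pairs[name] = (model_cfg[name], model_cfg.get(f"{name}_conf") or {})
    return pairs


# Illustrative values only; real recipes define their own settings.
model_cfg = {
    "encoder": "transformer",
    "encoder_conf": {"output_size": 256},
    "decoder": "transformer",
    "normalize": "global_mvn",
}

pairs = pair_components(model_cfg)
print(pairs["encoder"])    # ('transformer', {'output_size': 256})
print(pairs["decoder"])    # ('transformer', {})
print(pairs["normalize"])  # ('global_mvn', {})
```
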
ESPnet2 task reference
Below is a quick reference to ESPnet2 task names and their recipe docs.
| Task | Description |
|---|---|
| asr1 | Automatic Speech Recognition (Multi-tasking) |
| asr2 | Automatic Speech Recognition with Discrete Units |
| asvspoof1 | Speaker Verification Spoofing and Countermeasures |
| cls1 | Classification |
| codec1 | Speech Codec |
| diar1 | Speaker Diarisation |
| enh1 | Speech Enhancement |
| enh_asr1 | Speech Recognition with Speech Enhancement |
| enh_diar1 | Speaker Diarisation with Speech Enhancement |
| enh_st1 | Speech-to-Text Translation with Speech Enhancement |
| hubert1 | Self-supervised Learning |
| lid1 | Language Identification |
| lm1 | Language Modeling |
| mt1 | Machine Translation |
| s2st1 | Speech-to-Speech Translation |
| s2t1 | Weakly-supervised Learning (Speech-to-Text) |
| sds1 | ESPnet-SDS |
| slu1 | Spoken Language Understanding |
| speechlm1 | Speech Language Model |
| spk1 | Speaker Representation |
| ssl1 | Self-supervised Learning |
| st1 | Speech-to-Text Translation |
| svs1 | Singing Voice Synthesis |
| svs2 | ESPnet2 SVS2 Recipe TEMPLATE |
| tts1 | Text-to-Speech |
| tts2 | Text-to-Speech with Discrete Units |
| uasr1 | Unsupervised Automatic Speech Recognition |
Use custom/ESPnet3-only models (model._target_)
If you want an ESPnet3-specific or fully custom model, implement it under your recipe's src/ directory and point model._target_ to it:

    model:
      _target_: src.my_model.MyModel
      # custom args here

Training-time forward contract
ESPnetLightningModule calls model(**batch) on every training and validation step. Your model's forward must return exactly three values in this order:

    loss, stats, weight = model(**batch)

This is enforced: returning the wrong structure raises an error inside the training loop.
loss: scalar tensor for backprop
In the standard single-optimizer path, loss must be a scalar torch.Tensor that requires grad. Lightning takes this return value and calls .backward() on it automatically.

    loss = compute_ctc_loss(...)  # scalar, requires_grad=True

stats: the logging dict
stats must be a dict[str, Tensor]. The training loop logs every entry directly to TensorBoard / W&B under the key {mode}/{key}, where mode is train or valid.

    stats = {
        "loss": loss.detach(),  # -> logged as "train/loss" / "valid/loss"
        "acc": acc.detach(),    # -> logged as "train/acc" / "valid/acc"
    }

Every value must be detached. Forgetting .detach() keeps the gradient graph alive across steps and wastes GPU memory. It does not raise an error, so the mistake is easy to miss.
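
The {mode}/{key} naming described above can be sketched in a few lines. This is only an illustration of the key format: prefix_stats is a hypothetical helper, plain floats stand in for detached tensors, and the real training loop hands the values to the logger rather than returning a dict.

```python
from typing import Dict


def prefix_stats(stats: Dict[str, float], mode: str) -> Dict[str, float]:
    """Rename each stats key to "{mode}/{key}" as described above."""
    if mode not in ("train", "valid"):
        raise ValueError(f"unexpected mode: {mode!r}")
    return {f"{mode}/{key}": value for key, value in stats.items()}


# Floats stand in for the detached tensors a real model would return.
stats = {"loss": 0.42, "acc": 0.91}

print(prefix_stats(stats, "train"))  # {'train/loss': 0.42, 'train/acc': 0.91}
print(prefix_stats(stats, "valid"))  # {'valid/loss': 0.42, 'valid/acc': 0.91}
```
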
weight: batch size for weighted averaging
weight is a scalar tensor representing the number of samples in the batch. The training loop passes it as batch_size to Lightning's log_dict, which enables correct weighted averaging across variable-length batches and DDP workers.

    weight = speech.new_tensor(speech.shape[0])  # same device as loss

Use tensor.new_tensor(value) rather than torch.tensor(value, device=loss.device) to match the device (and dtype) of an existing tensor automatically.
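
To see why the batch size matters for averaging, consider two batches of different sizes. The sketch below is plain-Python arithmetic with made-up numbers, not Lightning's implementation; weighted_mean and naive_mean are hypothetical helpers.

```python
from typing import List, Tuple


def weighted_mean(batches: List[Tuple[float, int]]) -> float:
    """Average per-batch means, weighting each by its sample count."""
    total = sum(value * weight for value, weight in batches)
    count = sum(weight for _, weight in batches)
    return total / count


def naive_mean(batches: List[Tuple[float, int]]) -> float:
    """Average per-batch means while ignoring batch size."""
    return sum(value for value, _ in batches) / len(batches)


# (mean loss of the batch, number of samples in the batch); made-up numbers.
batches = [(2.0, 4), (1.0, 12)]

print(weighted_mean(batches))  # 1.25
print(naive_mean(batches))     # 1.5
```

The naive average overstates the loss here because the small batch counts as much as the large one; weighting by sample count gives the per-sample mean.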
Full example
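
    import torch


    class MyASRModel(torch.nn.Module):
        def forward(self, speech, speech_lengths, text, text_lengths, **kwargs):
            # _compute_loss is the model's own loss routine (not shown here).
            loss, acc = self._compute_loss(
                speech, speech_lengths, text, text_lengths
            )
            # Detach everything that is only logged.
            stats = {
                "loss": loss.detach(),
                "acc": acc.detach(),
            }
            # Batch size as a scalar tensor on the same device as the inputs.
            weight = speech.new_tensor(speech.shape[0])
            return loss, stats, weight

NaN / Inf handling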
If loss is NaN or Inf, the entire batch is skipped across all DDP workers at once. After 100 consecutive NaN batches, the training loop raises RuntimeError and stops.
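
The skip-and-abort behavior can be pictured as a small counter. The sketch below is a single-process illustration only: NonFiniteGuard is a hypothetical class, it does not model the agreement between DDP workers, and whether the real loop raises on the 100th or the 101st bad batch is an assumption (this version raises when the streak reaches the limit).

```python
import math


class NonFiniteGuard:
    """Count consecutive non-finite losses and abort after a limit.

    Hypothetical single-process stand-in; the real loop also has to agree
    on the skip decision across DDP workers, which is not modeled here.
    """

    def __init__(self, max_consecutive: int = 100) -> None:
        self.max_consecutive = max_consecutive
        self.consecutive = 0

    def should_skip(self, loss_value: float) -> bool:
        """Return True if this batch should be skipped."""
        if math.isfinite(loss_value):
            self.consecutive = 0  # any finite loss resets the streak
            return False
        self.consecutive += 1
        if self.consecutive >= self.max_consecutive:
            raise RuntimeError(
                f"{self.consecutive} consecutive non-finite losses; stopping."
            )
        return True


guard = NonFiniteGuard(max_consecutive=3)  # small limit to keep the demo short
print(guard.should_skip(0.7))           # False
print(guard.should_skip(float("nan")))  # True
print(guard.should_skip(float("inf")))  # True
print(guard.should_skip(0.5))           # False, and the streak resets
```
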
Multi-optimizer (GAN-style) training
For GAN-style or other multi-optimizer training, loss is replaced by one or more OptimizationStep objects that route each loss to the correct optimizer. stats and weight keep the same structure; weight may be None.

    from espnet3.components.modeling.optimization_spec import OptimizationStep


    def forward(self, **batch):
        g_loss = ...
        d_loss = ...
        stats = {
            "generator_loss": g_loss.detach(),
            "discriminator_loss": d_loss.detach(),
        }
        return [
            OptimizationStep(loss=g_loss, name="generator"),
            OptimizationStep(loss=d_loss, name="discriminator"),
        ], stats, None

The training loop automatically logs train/generator/loss, train/discriminator/loss, and per-optimizer update steps in addition to the keys in stats. Only the optimizers named in the returned steps are updated for that batch; others are left untouched.
See Optimizer Configuration for YAML wiring and per-optimizer gradient clipping.
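
The routing rule for multi-optimizer training (only the optimizers named in the returned steps are updated) can be illustrated without any training framework. In the sketch below, Step is a stand-in for OptimizationStep that carries only the loss and name fields shown earlier, route_steps is a hypothetical helper, and raising KeyError for an unknown name is an assumption rather than documented behavior.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class Step:
    """Stand-in for OptimizationStep with only the two fields shown above."""

    loss: float  # a real step would carry a tensor
    name: str


def route_steps(
    steps: List[Step], optimizer_names: List[str], mode: str = "train"
) -> Tuple[List[str], Dict[str, float]]:
    """Pick the optimizers to update and build "{mode}/{name}/loss" log keys."""
    known = set(optimizer_names)
    to_update: List[str] = []
    logs: Dict[str, float] = {}
    for step in steps:
        if step.name not in known:
            raise KeyError(f"no optimizer named {step.name!r}")  # assumed behavior
        to_update.append(step.name)
        logs[f"{mode}/{step.name}/loss"] = step.loss
    return to_update, logs


# This batch returns only a discriminator step, so the generator is untouched.
to_update, logs = route_steps(
    [Step(loss=0.8, name="discriminator")],
    optimizer_names=["generator", "discriminator"],
)
print(to_update)  # ['discriminator']
print(logs)       # {'train/discriminator/loss': 0.8}
```
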
Collect-stats support (collect_feats)
If you want to use collect_stats, your model should implement collect_feats(). See:
- Stage doc: doc/vuepress/src/espnet3/stages/collect-stats.html
- Config doc: doc/vuepress/src/espnet3/core/config/training.html
