espnet2.legacy.nets.pytorch_backend.tacotron2.decoder.ZoneOutCell
class espnet2.legacy.nets.pytorch_backend.tacotron2.decoder.ZoneOutCell(cell, zoneout_rate=0.1)
Bases: Module
ZoneOut Cell module.
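This module implements zoneout as described in "Zoneout: Regularizing RNNs by Randomly Preserving Hidden Activations". Instead of dropping activations, zoneout randomly keeps each unit's previous hidden value at each time step. The code is adapted from eladhoffer/seq2seq.pytorch.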
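For reference, the zoneout update from the cited paper can be written as follows. Here \tilde{h}_t is the state proposed by the wrapped cell, h_{t-1} is the previous state, and p is zoneout_rate. This follows the paper's formulation. Applying one rate to both the hidden and cell states of an LSTM is an assumption based on the single zoneout_rate parameter; the paper also allows separate rates.

```latex
\begin{aligned}
d_t &\sim \mathrm{Bernoulli}(p) \quad \text{(element-wise mask, training)} \\
h_t &= d_t \odot h_{t-1} + (1 - d_t) \odot \tilde{h}_t \quad \text{(training)} \\
h_t &= p \, h_{t-1} + (1 - p) \, \tilde{h}_t \quad \text{(inference, expected value)}
\end{aligned}
```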
Examples
>>> lstm = torch.nn.LSTMCell(16, 32)
>>> lstm = ZoneOutCell(lstm, 0.5)
Initialize the zoneout cell module.
- Parameters:
- cell (torch.nn.Module) – PyTorch recurrent cell module, e.g. torch.nn.LSTMCell.
- zoneout_rate (float, optional) – Probability of zoneout, in the range 0.0 to 1.0 (default: 0.1).
forward(inputs, hidden)
Calculate forward propagation.
- Parameters:
- inputs (Tensor) – Batch of input tensor (B, input_size).
- hidden (tuple) –
- Tensor: Batch of initial hidden states (B, hidden_size).
- Tensor: Batch of initial cell states (B, hidden_size).
- Returns:
- Tensor: Batch of next hidden states (B, hidden_size).
- Tensor: Batch of next cell states (B, hidden_size).
- Return type: tuple
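The constructor and forward contract above can be illustrated with a minimal PyTorch sketch. It assumes the wrapped cell follows the torch.nn.LSTMCell convention, cell(inputs, (h, c)) -> (next_h, next_c). It also assumes zoneout uses a random Bernoulli mask in training mode and the expected-value mix in eval mode, following the paper. `ZoneOutSketch` is a hypothetical name used only for illustration. It is not the ESPnet class, and ESPnet's actual implementation may differ in details such as mask sampling.

```python
import torch


class ZoneOutSketch(torch.nn.Module):
    """Hypothetical zoneout wrapper (illustrative sketch, not the ESPnet class).

    Assumes cell(inputs, (h, c)) -> (next_h, next_c), as with torch.nn.LSTMCell.
    """

    def __init__(self, cell: torch.nn.Module, zoneout_rate: float = 0.1):
        super().__init__()
        if not 0.0 <= zoneout_rate <= 1.0:
            raise ValueError("zoneout_rate must be in the range [0.0, 1.0].")
        self.cell = cell
        self.zoneout_rate = zoneout_rate

    def _zoneout(self, h, next_h, prob):
        # Recurse into (hidden, cell) tuples so that every state is zoned out.
        if isinstance(h, tuple):
            return tuple(self._zoneout(a, b, prob) for a, b in zip(h, next_h))
        if self.training:
            # Element-wise Bernoulli mask: 1 keeps the previous value, 0 takes the new one.
            mask = torch.bernoulli(torch.full_like(h, prob))
            return mask * h + (1.0 - mask) * next_h
        # In eval mode, use the expected value of the stochastic update.
        return prob * h + (1.0 - prob) * next_h

    def forward(self, inputs, hidden):
        next_hidden = self.cell(inputs, hidden)
        return self._zoneout(hidden, next_hidden, self.zoneout_rate)


torch.manual_seed(0)
batch, input_size, hidden_size = 4, 16, 32
cell = ZoneOutSketch(torch.nn.LSTMCell(input_size, hidden_size), zoneout_rate=0.5)
h = torch.zeros(batch, hidden_size)  # initial hidden states (B, hidden_size)
c = torch.zeros(batch, hidden_size)  # initial cell states (B, hidden_size)
x = torch.randn(batch, input_size)   # input batch (B, input_size)
next_h, next_c = cell(x, (h, c))
print(next_h.shape, next_c.shape)  # torch.Size([4, 32]) torch.Size([4, 32])
```

Zoneout only mixes the old and new states, so the wrapper keeps the (B, hidden_size) shapes of the wrapped cell. It can therefore replace a torch.nn.LSTMCell step in a decoder loop without other changes.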
