In the encoder interface, we only specify that the encoder takes variable-length sequences as input `X`. The implementation will be provided by any model that inherits this base `Encoder` class.
```python
from torch import nn

class Encoder(nn.Module):
    """The base encoder interface for the encoder-decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def forward(self, X, *args):
        raise NotImplementedError
```
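As an illustration of what a subclass might look like, here is a minimal sketch of a hypothetical GRU-based encoder. The names `ToyGRUEncoder`, `vocab_size`, `embed_size`, and `num_hiddens` are assumptions made for this example, not part of the interface, and the base class is repeated so the snippet runs on its own. The encoder maps a batch of token ids to per-step outputs plus a final hidden state.

```python
import torch
from torch import nn


class Encoder(nn.Module):
    """Copy of the base encoder interface so this snippet runs standalone."""
    def __init__(self):
        super().__init__()

    def forward(self, X, *args):
        raise NotImplementedError


class ToyGRUEncoder(Encoder):
    """Hypothetical encoder: embeds token ids and runs a single-layer GRU."""
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)

    def forward(self, X, *args):
        # X: (batch_size, num_steps) integer token ids
        embs = self.embedding(X)          # (batch_size, num_steps, embed_size)
        outputs, state = self.rnn(embs)   # outputs: (batch_size, num_steps, num_hiddens)
        return outputs, state             # state: (1, batch_size, num_hiddens)


encoder = ToyGRUEncoder(vocab_size=10, embed_size=8, num_hiddens=16)
X = torch.zeros((4, 9), dtype=torch.long)  # batch of 4 sequences, 9 tokens each
enc_outputs, enc_state = encoder(X)
print(enc_outputs.shape, enc_state.shape)
# torch.Size([4, 9, 16]) torch.Size([1, 4, 16])
```

Returning both the per-step outputs and the final state leaves each decoder free to pick whichever it needs when building its own state.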
The decoder interface has an additional `init_state` method that converts the encoder output (`enc_all_outputs`) into the encoded state. To generate a variable-length sequence token by token, at each time step the decoder may map an input (e.g., the token generated at the previous time step) and the encoded state into an output token for the current time step.
```python
class Decoder(nn.Module):
    """The base decoder interface for the encoder-decoder architecture."""
    def __init__(self):
        super().__init__()

    # Later there can be additional arguments (e.g., length excluding padding)
    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError
```
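The token-by-token generation described above can be sketched with a hypothetical GRU decoder and a greedy loop that feeds each prediction back in as the next input. This sketch assumes `enc_all_outputs` is an `(outputs, final_state)` pair like the one a GRU encoder returns; `ToyGRUDecoder`, `greedy_decode`, and `bos_id` are illustrative names, and random tensors stand in for a real encoder's output.

```python
import torch
from torch import nn


class Decoder(nn.Module):
    """Copy of the base decoder interface so this snippet runs standalone."""
    def __init__(self):
        super().__init__()

    def init_state(self, enc_all_outputs, *args):
        raise NotImplementedError

    def forward(self, X, state):
        raise NotImplementedError


class ToyGRUDecoder(Decoder):
    """Hypothetical decoder: a GRU initialized with the encoder's final state."""
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_all_outputs, *args):
        # Assumption: enc_all_outputs is (outputs, final_hidden_state)
        _, enc_state = enc_all_outputs
        return enc_state

    def forward(self, X, state):
        # X: (batch_size, num_steps) token ids; state: (1, batch_size, num_hiddens)
        outputs, state = self.rnn(self.embedding(X), state)
        return self.dense(outputs), state  # logits: (batch_size, num_steps, vocab_size)


def greedy_decode(decoder, enc_all_outputs, bos_id, num_steps):
    """Generate one token per step, feeding each prediction back as the next input."""
    state = decoder.init_state(enc_all_outputs)
    batch_size = state.shape[1]
    token = torch.full((batch_size, 1), bos_id, dtype=torch.long)
    generated = []
    for _ in range(num_steps):
        logits, state = decoder(token, state)  # logits: (batch_size, 1, vocab_size)
        token = logits.argmax(dim=-1)          # most likely token: (batch_size, 1)
        generated.append(token)
    return torch.cat(generated, dim=1)         # (batch_size, num_steps)


torch.manual_seed(0)
decoder = ToyGRUDecoder(vocab_size=10, embed_size=8, num_hiddens=16)
# Stand-in for real encoder output (hypothetical random values)
fake_enc = (torch.randn(4, 9, 16), torch.randn(1, 4, 16))
with torch.no_grad():
    tokens = greedy_decode(decoder, fake_enc, bos_id=0, num_steps=5)
print(tokens.shape)  # torch.Size([4, 5])
```

Greedy selection (always taking the highest-scoring token) is used here only because it is the simplest way to show the output at one step becoming the input at the next; the interface itself does not prescribe a search strategy.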
Putting the base classes together, we have:
```python
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder):
        super().__init__()  # required before assigning submodules
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        # Return decoder output only
        return self.decoder(dec_X, dec_state)[0]
```
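To see the pieces wired together, the following sketch plugs hypothetical GRU-based parts into `EncoderDecoder` and checks the output shape. The source and target sequences deliberately differ in length (9 vs. 7 tokens) and in vocabulary size (10 vs. 12), since nothing in the interface ties them together. `ToyEncoder` and `ToyDecoder` are assumptions for illustration, and `EncoderDecoder` is repeated so the snippet runs standalone.

```python
import torch
from torch import nn


class EncoderDecoder(nn.Module):
    """Same wiring as the base class above, repeated so this snippet runs standalone."""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_all_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_all_outputs, *args)
        return self.decoder(dec_X, dec_state)[0]


# Hypothetical concrete parts. In the text they would inherit from Encoder and
# Decoder; plain nn.Module keeps this snippet self-contained.
class ToyEncoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)

    def forward(self, X, *args):
        return self.rnn(self.embedding(X))  # (outputs, final_state)


class ToyDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_all_outputs, *args):
        return enc_all_outputs[1]  # use the encoder's final hidden state

    def forward(self, X, state):
        outputs, state = self.rnn(self.embedding(X), state)
        return self.dense(outputs), state


# Encoder and decoder share num_hiddens=16 so the state shapes are compatible
model = EncoderDecoder(ToyEncoder(10, 8, 16), ToyDecoder(12, 8, 16))
src = torch.zeros((4, 9), dtype=torch.long)  # source: 9 tokens per example
tgt = torch.zeros((4, 7), dtype=torch.long)  # target: 7 tokens per example
logits = model(src, tgt)
print(logits.shape)  # torch.Size([4, 7, 12])
```

Because `forward` keeps only element `[0]` of the decoder's return value, the model yields logits of shape (batch size, target length, target vocabulary size) and discards the final decoder state.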