Topics

A Transformer decoder block is a key component of the Transformer’s decoder, responsible for generating the output sequence. Like the encoder, the decoder consists of a stack of identical layers. Each decoder layer has three main sub-layers, each followed by a residual connection and a layer normalization step. Its primary job is to generate the output sequence based on the encoded inputs and the previously generated outputs. The structure is:

  1. Masked Multi-Head Self-Attention: The mask prevents the decoder from attending to subsequent tokens in the target sequence, ensuring that the prediction for the current position depends only on the tokens generated so far. This is essential for the auto-regressive property. Padding masks are also used here (a mask-construction sketch follows this list).
  2. Add & Norm: A residual connection and layer normalization are applied around the masked multi-head self-attention sub-layer.
  3. Encoder-Decoder Multi-Head Attention: The second sub-layer is another multi-head attention mechanism, but this time the queries (Q) come from the output of the previous decoder sub-layer, while the keys (K) and values (V) come from the output of the final Transformer encoder block (often referred to as the “memory”). This type of attention is also called cross-attention. It allows the decoder to attend to the encoded input sequence and draw relevant information for generating the target sequence. Padding masks for the encoder output are applied here.
  4. Second Add & Norm: Another residual connection and layer normalization are applied around the encoder-decoder attention sub-layer.
  5. Position-wise Feed-Forward Network: The third sub-layer is a position-wise feed-forward network, identical in structure to the one in the encoder (two linear layers with a ReLU activation in between).
  6. Third Add & Norm: A final residual connection and layer normalization are applied around the position-wise feed-forward network.
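
As a concrete illustration of the look-ahead mask in sub-layer 1, here is a minimal sketch in PyTorch (an assumption; the notes do not name a framework). The function name `make_causal_mask` and the boolean convention (True = blocked position) are illustrative choices, not a fixed API.

```python
import torch

def make_causal_mask(tgt_len):
    # True marks positions the decoder must NOT attend to (future tokens).
    # A padding mask (True where a target token is padding) would be passed
    # separately, e.g. as key_padding_mask in the attention call.
    return torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()

# Example: a length-4 target sequence.
print(make_causal_mask(4))
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```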

The data flow in a decoder layer is: Input → Masked Multi-Head Self-Attention → Add & Norm → Encoder-Decoder Multi-Head Attention → Add & Norm → Position-wise FFN → Add & Norm → Output.
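
To make this data flow concrete, below is a minimal sketch of one decoder layer using PyTorch’s `nn.MultiheadAttention`, in the post-norm arrangement of the original Transformer. The class name `DecoderLayer` and the hyperparameters `d_model`, `n_heads`, `d_ff` are illustrative assumptions, not a reference implementation.

```python
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model=512, n_heads=8, d_ff=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, causal_mask, tgt_pad_mask=None, mem_pad_mask=None):
        # 1–2. Masked self-attention, then Add & Norm.
        a, _ = self.self_attn(x, x, x, attn_mask=causal_mask,
                              key_padding_mask=tgt_pad_mask)
        x = self.norm1(x + a)
        # 3–4. Encoder-decoder (cross) attention, then Add & Norm:
        #      queries from the decoder, keys/values from the encoder memory.
        a, _ = self.cross_attn(x, memory, memory, key_padding_mask=mem_pad_mask)
        x = self.norm2(x + a)
        # 5–6. Position-wise FFN, then Add & Norm.
        x = self.norm3(x + self.ffn(x))
        return x
```

Stacking N such layers (plus the output embedding, positional encoding, and a final linear + softmax projection over the vocabulary) gives the full decoder.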


Key features:

Example

In machine translation, for translating “The cat sat on the mat” to French:

  • The decoder first outputs “Le” based on the encoded input English sentence
  • The previous outputs (“Le” in this case) are embedded, along with positional encoding
  • Masked self-attention focuses on the relevant outputs generated so far
  • Cross-attention focuses on the relevant encoded inputs (from the encoder block)
  • The decoder generates “chat”, considering all previous outputs (“Le”) and the encoded inputs
  • Generation continues until an end-of-sequence (EOS) token is produced (a decoding-loop sketch follows below)
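
A minimal sketch of the greedy autoregressive decoding loop this example describes. Here `model.encode`, `model.decode`, `bos_id`, and `eos_id` are hypothetical names standing in for whatever model interface and vocabulary are actually used.

```python
import torch

def greedy_decode(model, src_ids, bos_id, eos_id, max_len=50):
    # Run the encoder once; its output ("memory") is reused at every step.
    memory = model.encode(src_ids)
    out = torch.tensor([[bos_id]])                  # start with the <bos> token
    for _ in range(max_len):
        logits = model.decode(out, memory)          # shape: (1, cur_len, vocab)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        out = torch.cat([out, next_id], dim=1)      # append the prediction
        if next_id.item() == eos_id:                # stop at end-of-sequence
            break
    return out
```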