Topics

nn.Embedding and nn.Linear (applied to one-hot inputs, without a bias) essentially work the same, but there are a few reasons to prefer nn.Embedding for lookup:

  • The implementation is efficient: a lookup indexes rows of the weight matrix directly
  • With nn.Linear, the one-hot input causes a lot of wasteful multiplications by 0
  • Memory access pattern: with nn.Linear, the entire weight matrix is read from memory during the forward pass
  • During backpropagation with nn.Linear, the entire weight matrix is updated (though most of the updates are 0)
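
The points above can be checked with a small sketch. It assumes a toy vocabulary of 10 tokens and an embedding size of 4 (both hypothetical values chosen for illustration), copies the same weights into both layers, and compares the outputs and the gradients. It only demonstrates the equivalence and the gradient pattern; it does not measure speed or memory traffic.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embed_dim = 10, 4  # hypothetical sizes, for illustration only

embedding = nn.Embedding(vocab_size, embed_dim)
linear = nn.Linear(vocab_size, embed_dim, bias=False)

# nn.Linear stores its weight as (out_features, in_features),
# so the embedding table has to be transposed before copying.
with torch.no_grad():
    linear.weight.copy_(embedding.weight.T)

token_ids = torch.tensor([1, 7, 7, 3])

# Lookup: picks rows 1, 7, 7, 3 of the embedding table.
looked_up = embedding(token_ids)

# Same result via one-hot vectors and a full matrix multiplication.
one_hot = F.one_hot(token_ids, num_classes=vocab_size).float()
multiplied = linear(one_hot)

same_output = torch.allclose(looked_up, multiplied)

# Backpropagate a simple scalar loss through both paths.
looked_up.sum().backward()
multiplied.sum().backward()

# Both gradients have the full weight shape, but only the entries that
# belong to the tokens actually used (1, 3, 7) are nonzero.
rows_with_grad = embedding.weight.grad.abs().sum(dim=1).nonzero().flatten().tolist()
cols_with_grad = linear.weight.grad.abs().sum(dim=0).nonzero().flatten().tolist()

print(same_output, rows_with_grad, cols_with_grad)
```

In this sketch both gradients are dense tensors by default. nn.Embedding additionally accepts sparse=True, which makes the weight gradient a sparse tensor holding only the rows that were looked up; only some optimizers support sparse gradients.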