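nn.Embedding and nn.Linear compute the same thing for lookups. Reading row i of an nn.Embedding weight gives the same result as passing a one-hot vector for i through a bias-free nn.Linear. There are still a few reasons to prefer nn.Embedding for lookups: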
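- Efficient implementation: nn.Embedding indexes rows directly, while nn.Linear does a lot of wasteful multiplications by 0
- Memory access pattern: nn.Linear loads the entire weight matrix during the forward pass, while nn.Embedding reads only the rows being looked up
- Backpropagation: nn.Linear computes and applies a gradient for the entire weight matrix, even though most of its entries are 0. nn.Embedding only accumulates gradients into the looked-up rows, and with sparse=True it returns a sparse gradient instead of a dense one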
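The equivalence and the gradient pattern described above can be checked with a small PyTorch sketch. The sizes and indices are arbitrary illustrations. The sketch only demonstrates numerical equivalence and where the gradients are nonzero. It does not measure speed or memory use.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

num_embeddings, embedding_dim = 10, 4  # illustrative sizes
emb = nn.Embedding(num_embeddings, embedding_dim)
lin = nn.Linear(num_embeddings, embedding_dim, bias=False)

# nn.Linear stores its weight as (out_features, in_features),
# so it needs the transpose of the (num_embeddings, embedding_dim) embedding table
with torch.no_grad():
    lin.weight.copy_(emb.weight.T)

idx = torch.tensor([1, 3, 3, 7])  # index 3 is looked up twice

# Lookup: selects rows 1, 3, 3, 7 of emb.weight
out_emb = emb(idx)

# Same result via a dense matmul: (4 x 10) one-hot inputs times the full weight
one_hot = F.one_hot(idx, num_classes=num_embeddings).float()
out_lin = lin(one_hot)

# Backward through both paths with the same upstream gradient (all ones)
out_emb.sum().backward()
out_lin.sum().backward()

# Both gradients are stored densely, but are nonzero only for the used indices
emb_rows = emb.weight.grad.abs().sum(dim=1).nonzero().flatten().tolist()
lin_cols = lin.weight.grad.abs().sum(dim=0).nonzero().flatten().tolist()
print("nonzero embedding grad rows:", emb_rows)
print("nonzero linear grad columns:", lin_cols)

# With sparse=True, nn.Embedding returns a sparse gradient holding only the used rows
emb_sparse = nn.Embedding(num_embeddings, embedding_dim, sparse=True)
emb_sparse(idx).sum().backward()
sparse_rows = emb_sparse.weight.grad.coalesce().indices().flatten().tolist()
print("sparse grad rows:", sparse_rows)
```

Both layers end up with the same nonzero pattern, with index 3 receiving twice the gradient because it appears twice. The difference is how that result is computed. nn.Linear multiplies through the whole matrix. nn.Embedding only touches the selected rows. Sparse gradients work only with some optimizers, such as torch.optim.SGD and torch.optim.SparseAdam, so sparse=True is not a free switch.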