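`nn.Embedding` and `nn.Linear` (without bias) essentially compute the same thing: looking up index `i` in an embedding table returns row `i` of the weight matrix, which is exactly what a bias-free linear layer produces when its input is the one-hot vector for `i`. There are, however, a few reasons to prefer `nn.Embedding` for lookup: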
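- The implementation is more efficient.
- For `nn.Linear`, there are a lot of wasteful multiplications by 0 (every entry of a one-hot input except one is 0).
- Memory access pattern: for `nn.Linear`, the entire weight matrix is read during the forward pass, whereas a lookup touches only the requested rows.
- During backpropagation, the entire `nn.Linear` weight matrix is updated, even though most of the updates are 0.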