The current ESMC implementation lets pad tokens attend to each other. This does not affect non-pad tokens, but it does produce different final pad hidden states (and potentially wastes a lot of computation).
The current mask calculation:
mask_BLL = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask_BHLL = mask_BLL.unsqueeze(1)
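The equality comparison is what goes wrong: `False == False` evaluates to `True`, so every pad-pad position pair is marked as mutually visible, whereas a logical AND only admits pairs where both positions are real tokens. A minimal sketch of the difference:

```python
import torch

# One sequence of 6 positions, last 3 padded (seq_id is False at pad positions)
seq_id = torch.tensor([[True, True, True, False, False, False]])

# Equality comparison: pad-pad pairs are True (False == False),
# so pad tokens attend to each other
eq_mask = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)

# Logical AND: a pair is visible only if both query and key are real tokens
and_mask = seq_id.unsqueeze(-1) & seq_id.unsqueeze(-2)

print(eq_mask[0, 3:, 3:])   # pad-pad block: all True
print(and_mask[0, 3:, 3:])  # pad-pad block: all False
```

Both masks agree on the real-token block; they differ only where a pad token is the query or the key.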
This short script illustrates the issue and fix:
import torch

VOCAB_SIZE = 64
PAD_TOKEN = 0

# Sample from [1, VOCAB_SIZE) so a random token can never collide with the pad token
input_ids_1 = torch.randint(1, VOCAB_SIZE, (1, 6))
input_ids_2 = torch.randint(1, VOCAB_SIZE, (1, 6))
input_ids_2[:, -3:] = PAD_TOKEN  # pad out the last 3 positions of the second sequence
batch = torch.cat([input_ids_1, input_ids_2], dim=0)
seq_id = batch != PAD_TOKEN  # True for real tokens, False for padding
print("2D attention mask:")
print(seq_id)
print("4D attention mask from ESM repo:")
mask = seq_id.unsqueeze(-1) == seq_id.unsqueeze(-2)
mask = mask.unsqueeze(1)
print(mask)
print(mask.shape)
print("A correct 4D attention mask:")
correct_mask = seq_id[:, None, :, None] & seq_id[:, None, None, :]
print(correct_mask)
print(correct_mask.shape)
2D attention mask:
tensor([[ True, True, True, True, True, True],
[ True, True, True, False, False, False]])
Correct: the first element of the batch has no pad tokens, so everything is attended to
4D attention mask from ESM repo:
tensor([[[[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True]]],
Correct: the second element has 3 pad tokens, which the real tokens ignore
[[[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
Incorrect: the pad tokens attend to each other instead of to nothing
[False, False, False, True, True, True],
[False, False, False, True, True, True],
[False, False, False, True, True, True]]]])
(batch_size, 1, seq_len, seq_len) shape, which is correct

A correct 4D attention mask:
Correct: the first element of the batch has no pad tokens, so everything is attended to
tensor([[[[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True],
[ True, True, True, True, True, True]]],
Correct: the second element has 3 pad tokens, which the real tokens ignore
[[[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
[ True, True, True, False, False, False],
Correct: pad tokens attend to nothing
[False, False, False, False, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False]]]])
(batch_size, 1, seq_len, seq_len) shape, which is correct