Add a unit test for BartModel to compare eager, sdpa on one particular set of inputs #39435
xadupre wants to merge 51 commits into
[For maintainers] Suggested jobs to run (before merge) run-slow: bart
vasqu
left a comment
Some initial thoughts on the test: Does this only happen if we pass invalid masks? I.e. the mask was too long?
Added some comments on the test design to make it simpler and reuse things we already have.
context = torch.tensor(
    [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], device=torch_device, dtype=torch.long
)
mask = torch.ones((context.shape[0], context.shape[1] + 2), device=context.device, dtype=torch.int64)
So it fails when we give it a mask that is too long for the original input?
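To make the shapes in this question concrete, here is a minimal sketch that builds the same `context` tensor and the mask from the diff, then a mask with the same length as the input for comparison. It drops `device=torch_device` (a test helper) so it runs standalone, and `pad_token_id = 1` is an assumption made for illustration, not something stated in this thread.

```python
import torch

# Same token ids as in the diff above (device argument omitted to stay standalone).
context = torch.tensor(
    [[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]], dtype=torch.long
)

# Mask as built in the test: two columns longer than the input.
mask_from_test = torch.ones(
    (context.shape[0], context.shape[1] + 2), dtype=torch.int64
)

# Assumption for illustration only: token id 1 is the padding token.
pad_token_id = 1

# A mask with the same length as the input, zero on padded positions.
mask_matching_input = context.ne(pad_token_id).to(torch.int64)

print(tuple(context.shape), tuple(mask_from_test.shape), tuple(mask_matching_input.shape))
```

Running the test once with each mask would be one way to check whether the failure depends on the two extra mask columns.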
mask = torch.ones((context.shape[0], context.shape[1] + 2), device=context.device, dtype=torch.int64)
shape1 = (2, 2, 7, 7)
shape2 = (2, 2, 2, 7)
past_key_values = EncoderDecoderCache(
I'd rather we just pass use_cache in the forward call.
@parameterized.expand(["sdpa", "eager"])
def test_lm_uneven_forward_with_mask(self, attn_implementation):
I don't think we need to parametrize this; just set the config accordingly and check one after the other 👀
Other than that, I think it makes sense to move this test under BartModelTester/BartModelTest, you can look at test_encoder_decoder_model_standalone for reference. We should possibly only modify the inputs (attention mask) and keep the rest from the defaults, i.e. config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common().
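As a rough illustration of the "check one after the other" idea, the sketch below runs a hand-written eager attention and PyTorch's `torch.nn.functional.scaled_dot_product_attention` on the same inputs and the same padding mask, then compares the outputs. It is a generic, standalone sketch, not the BART implementation or the test from this PR: `eager_attention`, the tensor sizes, and the padded position are all hypothetical choices.

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, attn_mask):
    """Plain softmax attention; attn_mask is an additive float mask."""
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
    scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 2, 7, 8  # hypothetical sizes
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))

# Boolean padding mask: the last position of the second sequence is padding.
keep = torch.ones(batch, seq_len, dtype=torch.bool)
keep[1, -1] = False

# Turn it into an additive mask broadcastable to (batch, heads, q_len, k_len).
additive_mask = torch.zeros(batch, 1, 1, seq_len)
additive_mask = additive_mask.masked_fill(
    ~keep[:, None, None, :], torch.finfo(torch.float32).min
)

# Run both implementations one after the other on identical inputs.
eager_out = eager_attention(q, k, v, additive_mask)
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)

max_abs_diff = (eager_out - sdpa_out).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

In a model-level test the same pattern would apply: build the inputs once, run the model with each attention implementation in turn, and compare the outputs with a tolerance rather than exact equality.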
@xadupre Still any interest in this PR/bug? Otherwise, I'd make another PR to apply this fix and propagate it 👀
What does this PR do?
Fixes #39365. This is not a fix yet; it introduces a unit test that fails for the reason explained in that issue. Either the inputs are wrong or the fix from issue #39365 is needed.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.