update packing

This commit is contained in:
hiyouga
2024-07-04 01:10:55 +08:00
parent a36e8f2dd5
commit cce7083024
6 changed files with 133 additions and 271 deletions

View File

@@ -52,4 +52,5 @@ def test_4d_attention_mask():
],
dtype=torch.float16,
)
assert list(attention_mask_computed.size()) == [2, 1, 6, 6]
assert torch.all(attention_mask_computed == attention_mask_expected)