update packing

This commit is contained in:
hiyouga
2024-07-04 01:10:55 +08:00
parent a36e8f2dd5
commit cce7083024
6 changed files with 133 additions and 271 deletions

View File

@@ -29,20 +29,22 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
e.g.
```
[1, 1, 2, 2, 2, 0]
[[1, 1, 2, 2, 2, 0]]
```
->
```
[[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, o, x, x, x],
]
]]
[
[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, o, x, x, x],
]
]
]
```
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""