mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
@@ -12,9 +12,105 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from PIL import Image

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer


TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
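
# These pytest-style tests assume the `llamafactory` package is importable;
# they can be run with, e.g., `pytest -k collator`.

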
def test_base_collator():
    model_args, data_args, *_ = get_infer_args({"model_name_or_path": TINY_LLAMA, "template": "default"})
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    data_collator = MultiModalDataCollatorForSeq2Seq(
        template=template,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    p = tokenizer_module["tokenizer"].pad_token_id  # pad token id
    q = IGNORE_INDEX  # label padding value
    features = [
        {
            "input_ids": [0, 1, 2, 3, 4, 5],
            "attention_mask": [1, 1, 1, 1, 1, 1],
            "labels": [q, q, 2, 3, 4, 5],
        },
        {
            "input_ids": [6, 7],
            "attention_mask": [1, 1],
            "labels": [q, 7],
        },
    ]
    batch_input = data_collator(features)
    # both sequences are padded to length 8 (pad_to_multiple_of=8)
    expected_input = {
        "input_ids": [
            [0, 1, 2, 3, 4, 5, p, p],
            [6, 7, p, p, p, p, p, p],
        ],
        "attention_mask": [
            [1, 1, 1, 1, 1, 1, 0, 0],
            [1, 1, 0, 0, 0, 0, 0, 0],
        ],
        "labels": [
            [q, q, 2, 3, 4, 5, q, q],
            [q, 7, q, q, q, q, q, q],
        ],
    }
    for k in batch_input.keys():
        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()


def test_multimodal_collator():
    model_args, data_args, *_ = get_infer_args(
        {"model_name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "template": "qwen2_vl"}
    )
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    data_collator = MultiModalDataCollatorForSeq2Seq(
        template=template,
        pad_to_multiple_of=4,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    p = tokenizer_module["tokenizer"].pad_token_id  # pad token id
    q = IGNORE_INDEX  # label padding value
    s = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_start|>")
    e = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|vision_end|>")
    m = tokenizer_module["tokenizer"].convert_tokens_to_ids("<|image_pad|>")
    fake_image = Image.new("RGB", (64, 64), (255, 255, 255))

    features = [
        {
            "input_ids": [0, 1, 2, 3],
            "attention_mask": [1, 1, 1, 1],
            "labels": [0, 1, 2, 3],
        },
    ]
    batch_input = data_collator(features)
    # the batch contains no images, so the collator appends a dummy image
    # (<|vision_start|> + image pads + <|vision_end|>); the appended tokens are
    # masked out in both the attention mask and the labels, and the dummy
    # image's pixel values are merged into the batch
    expected_input = {
        "input_ids": [
            [0, 1, 2, 3, s, m, m, m, m, e, p, p],
        ],
        "attention_mask": [
            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        ],
        "labels": [
            [0, 1, 2, 3, q, q, q, q, q, q, q, q],
        ],
        **tokenizer_module["processor"].image_processor(fake_image),
    }
    for k in batch_input.keys():
        assert batch_input[k].eq(torch.tensor(expected_input[k])).all()


def test_4d_attention_mask():
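    # The body below is a sketch, not necessarily the upstream test: it assumes
    # prepare_4d_attention_mask expands a 2D mask of packed-sequence indices
    # (0 = padding, 1..n = packed sequence id) into a (bsz, 1, seq_len, seq_len)
    # additive mask in which each position attends only to earlier positions of
    # the same packed sequence, with `o` (0.0) marking visible entries and `x`
    # (the dtype minimum) marking masked ones.
    o = 0.0
    x = torch.finfo(torch.float16).min
    attention_mask_with_indices = torch.tensor(
        [
            [1, 1, 2, 2, 2, 0],  # two packed sequences plus one trailing pad
        ]
    )
    attention_mask_computed = prepare_4d_attention_mask(attention_mask_with_indices, torch.float16)
    attention_mask_expected = torch.tensor(
        [
            [
                [
                    [o, x, x, x, x, x],
                    [o, o, x, x, x, x],
                    [x, x, o, x, x, x],
                    [x, x, o, o, x, x],
                    [x, x, o, o, o, x],
                    [x, x, x, x, x, x],
                ]
            ]
        ],
        dtype=torch.float16,
    )
    assert list(attention_mask_computed.size()) == [1, 1, 6, 6]
    assert torch.all(attention_mask_computed == attention_mask_expected)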