Fix Jinja template

This commit is contained in:
hiyouga
2024-06-19 20:03:50 +08:00
parent 4cff6a4ad5
commit 2b596fb55f
3 changed files with 46 additions and 4 deletions

View File

@@ -17,6 +17,7 @@ import random
import pytest
from datasets import load_dataset
from transformers import AutoTokenizer
from llamafactory.data import get_dataset
from llamafactory.hparams import get_train_args
@@ -48,10 +49,11 @@ def test_supervised(num_samples: int):
tokenizer = tokenizer_module["tokenizer"]
tokenized_data = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
ref_tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
original_data = load_dataset(TRAIN_ARGS["dataset"], split="train")
indexes = random.choices(range(len(original_data)), k=num_samples)
for index in indexes:
decoded_result = tokenizer.decode(tokenized_data["input_ids"][index])
prompt = original_data[index]["instruction"]
if original_data[index]["input"]:
prompt += "\n" + original_data[index]["input"]
@@ -60,5 +62,6 @@ def test_supervised(num_samples: int):
{"role": "user", "content": prompt},
{"role": "assistant", "content": original_data[index]["output"]},
]
templated_result = tokenizer.apply_chat_template(messages, tokenize=False)
assert decoded_result == templated_result
templated_result = ref_tokenizer.apply_chat_template(messages, tokenize=False)
decoded_result = tokenizer.decode(tokenized_data["input_ids"][index])
assert templated_result == decoded_result