update get template

This commit is contained in:
hiyouga
2024-09-04 22:36:20 +08:00
parent 8f441c2b3a
commit dabad5570b
17 changed files with 57 additions and 56 deletions

View File

@@ -18,7 +18,7 @@ from collections import defaultdict
import fire
from tqdm import tqdm
from llamafactory.data import get_dataset
from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer
@@ -48,7 +48,8 @@ def length_cdf(
)
)
tokenizer_module = load_tokenizer(model_args)
trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
total_num = len(trainset)
length_dict = defaultdict(int)
for sample in tqdm(trainset["input_ids"]):