From d5ea05cfff55de5d736ef4e70099794031652eb6 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Wed, 4 Sep 2024 22:36:20 +0800
Subject: [PATCH] update get template

Former-commit-id: dabad5570bf4a6b1044c963d8f27717030f373ef
---
 scripts/cal_lr.py                      |  5 +++--
 scripts/cal_ppl.py                     |  5 +++--
 scripts/length_cdf.py                  |  5 +++--
 src/llamafactory/chat/hf_engine.py     |  2 +-
 src/llamafactory/chat/vllm_engine.py   |  2 +-
 src/llamafactory/data/loader.py        | 18 ++++++------------
 src/llamafactory/data/template.py      | 26 +++++++++++++-------------
 src/llamafactory/eval/evaluator.py     |  2 +-
 src/llamafactory/hparams/data_args.py  |  4 ----
 src/llamafactory/train/dpo/workflow.py |  5 +++--
 src/llamafactory/train/kto/workflow.py |  5 +++--
 src/llamafactory/train/ppo/workflow.py |  5 +++--
 src/llamafactory/train/pt/workflow.py  |  5 +++--
 src/llamafactory/train/rm/workflow.py  |  5 +++--
 src/llamafactory/train/sft/workflow.py |  7 ++++---
 src/llamafactory/train/test_utils.py   |  5 +++--
 tests/data/test_template.py            |  7 ++++---
 17 files changed, 57 insertions(+), 56 deletions(-)

diff --git a/scripts/cal_lr.py b/scripts/cal_lr.py
index a9b27b37..9f5737f5 100644
--- a/scripts/cal_lr.py
+++ b/scripts/cal_lr.py
@@ -25,7 +25,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
 
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
@@ -66,7 +66,8 @@ def calculate_lr(
     )
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
diff --git a/scripts/cal_ppl.py b/scripts/cal_ppl.py
index 1a5f9034..14b5d1ba 100644
--- a/scripts/cal_ppl.py
+++ b/scripts/cal_ppl.py
@@ -23,7 +23,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
 
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_model, load_tokenizer
@@ -88,7 +88,8 @@ def cal_ppl(
     )
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
     model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
diff --git a/scripts/length_cdf.py b/scripts/length_cdf.py
index 65a51872..d8885731 100644
--- a/scripts/length_cdf.py
+++ b/scripts/length_cdf.py
@@ -18,7 +18,7 @@ from collections import defaultdict
 import fire
 from tqdm import tqdm
 
-from llamafactory.data import get_dataset
+from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.hparams import get_train_args
 from llamafactory.model import load_tokenizer
 
@@ -48,7 +48,8 @@ def length_cdf(
         )
     )
     tokenizer_module = load_tokenizer(model_args)
-    trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
     total_num = len(trainset)
     length_dict = defaultdict(int)
     for sample in tqdm(trainset["input_ids"]):
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 9817318d..880e5803 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -54,7 +54,7 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.model = load_model(
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index ed4613d0..7d34965a 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -68,7 +68,7 @@ class VllmEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.generating_args = generating_args.to_dict()
 
         engine_args = {
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index be37f38f..ff042f4f 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -14,7 +14,7 @@
 
 import os
 import sys
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 
 import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
@@ -27,7 +27,6 @@ from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer
 
 
 if TYPE_CHECKING:
@@ -179,9 +178,6 @@
             load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
             desc="Running tokenizer on dataset",
         )
-        if data_args.dataset_map_batch_size:
-            # Set the batch size conditionally without considering the default variable of the batch size in the map function
-            kwargs.update(batch_size=data_args.dataset_map_batch_size)
 
     dataset = dataset.map(
         preprocess_func,
@@ -205,17 +201,14 @@
 
 
 def get_dataset(
+    template: "Template",
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
-) -> Tuple["DatasetModule", "Template"]:
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
+) -> "DatasetModule":
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
@@ -233,7 +226,7 @@ def get_dataset(
             if data_args.streaming:
                 dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
 
-            return dataset_module, template
+            return dataset_module
 
     if data_args.streaming:
         raise ValueError("Turn off `streaming` when saving dataset to disk.")
@@ -280,7 +273,8 @@ def get_dataset(
     dataset_module = {}
     if "train" in dataset_dict:
         dataset_module["train_dataset"] = dataset_dict["train"]
+
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]
 
-    return dataset_module, template
+    return dataset_module
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index a4d62f58..818e5625 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -27,6 +27,7 @@ from .mm_plugin import get_mm_plugin
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
+    from ..hparams import DataArguments
     from .formatter import SLOTS, Formatter
     from .mm_plugin import BasePlugin
 
@@ -344,28 +345,27 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
     return jinja_template
 
 
-def get_template_and_fix_tokenizer(
-    tokenizer: "PreTrainedTokenizer",
-    name: Optional[str] = None,
-    tool_format: Optional[str] = None,
-) -> Template:
-    if name in ["llava", "paligemma", "qwen2_vl"]:
+def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
         require_version(
             "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
         )
 
-    if name is None:
+    if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
-        template = TEMPLATES.get(name, None)
+        template = TEMPLATES.get(data_args.template, None)
         if template is None:
-            raise ValueError("Template {} does not exist.".format(name))
+            raise ValueError("Template {} does not exist.".format(data_args.template))
 
-    if tool_format is not None:
-        logger.info("Using tool format: {}.".format(tool_format))
+    if data_args.train_on_prompt and template.efficient_eos:
+        raise ValueError("Current template does not support `train_on_prompt`.")
+
+    if data_args.tool_format is not None:
+        logger.info("Using tool format: {}.".format(data_args.tool_format))
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
-        template.format_tools = ToolFormatter(tool_format=tool_format)
+        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
     stop_words = template.stop_words
     if template.replace_eos:
diff --git a/src/llamafactory/eval/evaluator.py b/src/llamafactory/eval/evaluator.py
index f05e01a1..fb0aa4a4 100644
--- a/src/llamafactory/eval/evaluator.py
+++ b/src/llamafactory/eval/evaluator.py
@@ -59,7 +59,7 @@ class Evaluator:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
         self.tokenizer = load_tokenizer(self.model_args)["tokenizer"]
         self.tokenizer.padding_side = "right"  # avoid overflow issue in batched inference for llama2
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args.template)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, self.data_args)
         self.model = load_model(self.tokenizer, self.model_args, finetuning_args)
         self.eval_template = get_eval_template(self.eval_args.lang)
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index a03128c6..1adcf2d0 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -113,10 +113,6 @@ class DataArguments:
         default=None,
         metadata={"help": "Path to save or load the tokenized datasets."},
     )
-    dataset_map_batch_size: Optional[int] = field(
-        default=None,
-        metadata={"help": "Batch size for dataset mapping."},
-    )
 
     def __post_init__(self):
         def split_arg(arg):
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 5135f5a2..3a8464ec 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -17,7 +17,7 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,7 +41,8 @@ def run_dpo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     data_collator = PairwiseDataCollatorWithPadding(
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index 8d282685..81c30a14 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -17,7 +17,7 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import KTODataCollatorWithPadding, get_dataset
+from ...data import KTODataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
@@ -41,7 +41,8 @@ def run_kto(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="kto", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="kto", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     data_collator = KTODataCollatorWithPadding(
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 7fa5c252..92262f99 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -17,7 +17,7 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset
+from ...data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -41,7 +41,8 @@ def run_ppo(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="ppo", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="ppo", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
 
     tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py
index 91c66fa9..06afdc12 100644
--- a/src/llamafactory/train/pt/workflow.py
+++ b/src/llamafactory/train/pt/workflow.py
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 from transformers import DataCollatorForLanguageModeling
 
-from ...data import get_dataset
+from ...data import get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -42,7 +42,8 @@ def run_pt(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, _ = get_dataset(model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index 9adf5827..e3f1b762 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -17,7 +17,7 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import PairwiseDataCollatorWithPadding, get_dataset
+from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..callbacks import fix_valuehead_checkpoint
@@ -41,7 +41,8 @@ def run_rm(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
 
     data_collator = PairwiseDataCollatorWithPadding(template=template, pad_to_multiple_of=8, **tokenizer_module)
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index a577e879..43a9aef1 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -17,7 +17,7 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset
+from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
@@ -43,7 +43,8 @@ def run_sft(
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
-    dataset_module, template = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
     model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
 
     if getattr(model, "is_quantized", False) and not training_args.do_train:
@@ -62,7 +63,7 @@ def run_sft(
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
-    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
+    training_args.remove_unused_columns = False  # important for multimodal dataset
 
     # Metric utils
     metric_module = {}
diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py
index 2a3b3eee..649a4795 100644
--- a/src/llamafactory/train/test_utils.py
+++ b/src/llamafactory/train/test_utils.py
@@ -19,7 +19,7 @@ from peft import PeftModel
 from transformers import AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
 
-from ..data import get_dataset
+from ..data import get_dataset, get_template_and_fix_tokenizer
 from ..extras.misc import get_current_device
 from ..hparams import get_infer_args, get_train_args
 from ..model import load_model, load_tokenizer
@@ -105,7 +105,8 @@ def load_reference_model(
 def load_train_dataset(**kwargs) -> "Dataset":
     model_args, data_args, training_args, _, _ = get_train_args(kwargs)
     tokenizer_module = load_tokenizer(model_args)
-    dataset_module, _ = get_dataset(model_args, data_args, training_args, stage=kwargs["stage"], **tokenizer_module)
+    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, kwargs["stage"], **tokenizer_module)
     return dataset_module["train_dataset"]
 
 
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index 360163e7..0c5e669e 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -19,6 +19,7 @@ import pytest
 from transformers import AutoTokenizer
 
 from llamafactory.data import get_template_and_fix_tokenizer
+from llamafactory.hparams import DataArguments
 
 
 if TYPE_CHECKING:
@@ -51,7 +52,7 @@ def _check_single_template(
     tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, token=HF_TOKEN)
     content_str = tokenizer.apply_chat_template(MESSAGES, tokenize=False)
     content_ids = tokenizer.apply_chat_template(MESSAGES, tokenize=True)
-    template = get_template_and_fix_tokenizer(tokenizer, name=template_name)
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template=template_name))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     assert content_str == prompt_str + answer_str + extra_str
     assert content_ids == prompt_ids + answer_ids + tokenizer.encode(extra_str, add_special_tokens=False)
@@ -78,7 +79,7 @@ def _check_template(model_id: str, template_name: str, prompt_str: str, answer_s
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_oneturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
     prompt_str = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
@@ -93,7 +94,7 @@ def test_encode_oneturn(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_encode_multiturn(use_fast: bool):
     tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA, use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, name="llama3")
+    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
     encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
     prompt_str_1 = (
         "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
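
Usage note (editor's sketch, not part of the patch): after this change, get_template_and_fix_tokenizer() takes the whole DataArguments object instead of separate template/tool_format arguments, the caller builds the template once, and get_dataset() receives it as its first argument and returns only the dataset module. The Python sketch below, modeled on scripts/cal_lr.py and src/llamafactory/train/test_utils.py from this patch, shows the new calling convention; the values passed to get_train_args() are hypothetical placeholders and the exact set of required hyperparameters may differ.

    # Sketch of the caller-side flow after this patch (argument values are placeholders).
    from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
    from llamafactory.hparams import get_train_args
    from llamafactory.model import load_tokenizer

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path="path/to/model",  # placeholder
            dataset="my_dataset",                # placeholder
            template="llama3",
            cutoff_len=1024,
            output_dir="dummy_dir",
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    # The template is now built by the caller and passed into get_dataset().
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)
    train_dataset = dataset_module["train_dataset"]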