From a9b0da597b0d61d113388f4e54989cb203d8c942 Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Thu, 25 Apr 2024 22:04:09 +0800
Subject: [PATCH] modify some style

Former-commit-id: c27f7fbf62b4a00b5794bb20621a4060a82490b7
---
 src/llmtuner/data/preprocess.py    |  5 +--
 src/llmtuner/data/template.py      | 50 ++++++++++++++++++++++++++++--
 src/llmtuner/train/sft/workflow.py |  2 +-
 src/llmtuner/train/tuner.py        |  1 +
 4 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/src/llmtuner/data/preprocess.py b/src/llmtuner/data/preprocess.py
index 51af8060..9cdcdfa2 100644
--- a/src/llmtuner/data/preprocess.py
+++ b/src/llmtuner/data/preprocess.py
@@ -324,10 +324,7 @@ def get_preprocess_and_print_func(
         print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
     else:
         preprocess_func = partial(
-            preprocess_unsupervised_dataset,
-            tokenizer=tokenizer,
-            template=template,
-            data_args=data_args,
+            preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
         )
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
 
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index f798ba5a..9a3673c3 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -11,6 +11,7 @@
 if TYPE_CHECKING:
     from .formatter import SLOTS, Formatter
 
+
 logger = get_logger(__name__)
 
 
@@ -103,9 +104,7 @@ class Template:
         return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
 
     def _convert_elements_to_ids(
-        self,
-        tokenizer: "PreTrainedTokenizer",
-        elements: List[Union[str, Dict[str, str]]],
+        self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
     ) -> List[int]:
         r"""
         Converts elements to token ids.
@@ -391,6 +390,7 @@ _register_template(
     ),
 )
 
+
 _register_template(
     name="aquila",
     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
@@ -403,6 +403,7 @@ _register_template(
     efficient_eos=True,
 )
 
+
 _register_template(
     name="atom",
     format_user=StringFormatter(
@@ -411,18 +412,21 @@ _register_template(
     format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
 )
 
+
 _register_template(
     name="baichuan",
     format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]),
     efficient_eos=True,
 )
 
+
 _register_template(
     name="baichuan2",
     format_user=StringFormatter(slots=["{{content}}"]),
     efficient_eos=True,
 )
 
+
 _register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
@@ -431,11 +435,13 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="bluelm",
     format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
 )
 
+
 _register_template(
     name="breeze",
     format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@@ -447,6 +453,7 @@ _register_template(
     efficient_eos=True,
 )
 
+
 _register_template(
     name="chatglm2",
     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@@ -456,6 +463,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="chatglm3",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -470,6 +478,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="chatglm3_system",
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@@ -489,6 +498,7 @@ _register_template(
     efficient_eos=True,
 )
 
+
 _register_template(
     name="chatml",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -499,6 +509,7 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="chatml_de",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -510,12 +521,14 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="codegeex2",
     format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
     force_system=True,
 )
 
+
 _register_template(
     name="cohere",
     format_user=StringFormatter(
@@ -530,6 +543,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="cpm",
     format_user=StringFormatter(slots=["<用户>{{content}}"]),
@@ -537,6 +551,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="dbrx",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -562,6 +577,7 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@@ -569,6 +585,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="deepseekcoder",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@@ -584,6 +601,7 @@ _register_template(
     efficient_eos=True,
 )
 
+
 _register_template(
     name="default",
     format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
@@ -591,12 +609,14 @@ _register_template(
     format_separator=EmptyFormatter(slots=["\n"]),
 )
 
+
 _register_template(
     name="empty",
     format_user=StringFormatter(slots=["{{content}}"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
 )
 
+
 _register_template(
     name="falcon",
     format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@@ -604,12 +624,14 @@ _register_template(
     efficient_eos=True,
 )
 
+
 _register_template(
     name="fewshot",
     format_separator=EmptyFormatter(slots=["\n\n"]),
     efficient_eos=True,
 )
 
+
 _register_template(
     name="gemma",
     format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]),
@@ -622,6 +644,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": ""}, "\n<|Bot|>:"]),
@@ -630,6 +653,7 @@ _register_template(
     efficient_eos=True,
 )
 
+
 _register_template(
     name="intern2",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -646,6 +670,7 @@ _register_template(
     efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
 )
 
+
 _register_template(
     name="llama2",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -662,6 +687,7 @@ _register_template(
     ),
 )
 
+
 _register_template(
     name="llama2_zh",
     format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@@ -669,6 +695,7 @@ _register_template(
     default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
 )
 
+
 _register_template(
     name="llama3",
     format_user=StringFormatter(
@@ -695,6 +722,7 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="mistral",
     format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
@@ -702,6 +730,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="olmo",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
@@ -710,6 +739,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="openchat",
     format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
@@ -718,6 +748,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="orion",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@@ -725,6 +756,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="phi",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@@ -736,6 +768,7 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -747,6 +780,7 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="solar",
     format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@@ -754,6 +788,7 @@ _register_template(
     efficient_eos=True,
 )
 
+
 _register_template(
     name="starchat",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
@@ -764,6 +799,7 @@ _register_template(
     force_system=True,
 )
 
+
 _register_template(
     name="vicuna",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@@ -773,6 +809,7 @@ _register_template(
     ),
 )
 
+
 _register_template(
     name="xuanyuan",
     format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@@ -783,11 +820,13 @@ _register_template(
     ),
 )
 
+
 _register_template(
     name="xverse",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
 )
 
+
 _register_template(
     name="yayi",
     format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
@@ -807,6 +846,7 @@ _register_template(
     stop_words=["<|End|>"],
 )
 
+
 _register_template(
     name="yi",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@@ -815,6 +855,7 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="yuan",
     format_user=StringFormatter(slots=["{{content}}", {"token": ""}]),
@@ -823,6 +864,7 @@ _register_template(
     replace_eos=True,
 )
 
+
 _register_template(
     name="zephyr",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
@@ -831,12 +873,14 @@ _register_template(
     default_system="You are a friendly chatbot who always responds in the style of a pirate",
 )
 
+
 _register_template(
     name="ziya",
     format_user=StringFormatter(slots=[":{{content}}\n:"]),
     format_separator=EmptyFormatter(slots=["\n"]),
 )
 
+
 _register_template(
     name="llava",
     format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT: "]),
diff --git a/src/llmtuner/train/sft/workflow.py b/src/llmtuner/train/sft/workflow.py
index 205142e5..c5acb4bc 100644
--- a/src/llmtuner/train/sft/workflow.py
+++ b/src/llmtuner/train/sft/workflow.py
@@ -43,7 +43,7 @@ def run_sft(
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None,  # for shift short attention
-        label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id),
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
     )
 
     # Override the decoding parameters of Seq2SeqTrainer
diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py
index e1999946..a8a2b8e9 100644
--- a/src/llmtuner/train/tuner.py
+++ b/src/llmtuner/train/tuner.py
@@ -19,6 +19,7 @@ from .sft import run_sft
 if TYPE_CHECKING:
     from transformers import TrainerCallback
 
+
 logger = get_logger(__name__)