modify some style

Former-commit-id: c27f7fbf62b4a00b5794bb20621a4060a82490b7
This commit is contained in:
BUAADreamer 2024-04-25 22:04:09 +08:00
parent a0be27fc9b
commit a9b0da597b
4 changed files with 50 additions and 8 deletions

View File

@ -324,10 +324,7 @@ def get_preprocess_and_print_func(
print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer) print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
else: else:
preprocess_func = partial( preprocess_func = partial(
preprocess_unsupervised_dataset, preprocess_unsupervised_dataset, tokenizer=tokenizer, template=template, data_args=data_args
tokenizer=tokenizer,
template=template,
data_args=data_args,
) )
print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer) print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

View File

@ -11,6 +11,7 @@ if TYPE_CHECKING:
from .formatter import SLOTS, Formatter from .formatter import SLOTS, Formatter
logger = get_logger(__name__) logger = get_logger(__name__)
@ -103,9 +104,7 @@ class Template:
return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len) return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
def _convert_elements_to_ids( def _convert_elements_to_ids(
self, self, tokenizer: "PreTrainedTokenizer", elements: List[Union[str, Dict[str, str]]]
tokenizer: "PreTrainedTokenizer",
elements: List[Union[str, Dict[str, str]]],
) -> List[int]: ) -> List[int]:
r""" r"""
Converts elements to token ids. Converts elements to token ids.
@ -391,6 +390,7 @@ _register_template(
), ),
) )
_register_template( _register_template(
name="aquila", name="aquila",
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
@ -403,6 +403,7 @@ _register_template(
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="atom", name="atom",
format_user=StringFormatter( format_user=StringFormatter(
@ -411,18 +412,21 @@ _register_template(
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
) )
_register_template( _register_template(
name="baichuan", name="baichuan",
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]), format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="baichuan2", name="baichuan2",
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]), format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="belle", name="belle",
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
@ -431,11 +435,13 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="bluelm", name="bluelm",
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
) )
_register_template( _register_template(
name="breeze", name="breeze",
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
@ -447,6 +453,7 @@ _register_template(
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="chatglm2", name="chatglm2",
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
@ -456,6 +463,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="chatglm3", name="chatglm3",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@ -470,6 +478,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="chatglm3_system", name="chatglm3_system",
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
@ -489,6 +498,7 @@ _register_template(
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="chatml", name="chatml",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -499,6 +509,7 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="chatml_de", name="chatml_de",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -510,12 +521,14 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="codegeex2", name="codegeex2",
format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]), format_system=StringFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}, "{{content}}"]),
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="cohere", name="cohere",
format_user=StringFormatter( format_user=StringFormatter(
@ -530,6 +543,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="cpm", name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]), format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
@ -537,6 +551,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="dbrx", name="dbrx",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -562,6 +577,7 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="deepseek", name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@ -569,6 +585,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="deepseekcoder", name="deepseekcoder",
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
@ -584,6 +601,7 @@ _register_template(
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="default", name="default",
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]), format_user=StringFormatter(slots=["Human: {{content}}\nAssistant: "]),
@ -591,12 +609,14 @@ _register_template(
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
) )
_register_template( _register_template(
name="empty", name="empty",
format_user=StringFormatter(slots=["{{content}}"]), format_user=StringFormatter(slots=["{{content}}"]),
format_assistant=StringFormatter(slots=["{{content}}"]), format_assistant=StringFormatter(slots=["{{content}}"]),
) )
_register_template( _register_template(
name="falcon", name="falcon",
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
@ -604,12 +624,14 @@ _register_template(
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="fewshot", name="fewshot",
format_separator=EmptyFormatter(slots=["\n\n"]), format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="gemma", name="gemma",
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]), format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
@ -622,6 +644,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="intern", name="intern",
format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]), format_user=StringFormatter(slots=["<|User|>:{{content}}", {"token": "<eoh>"}, "\n<|Bot|>:"]),
@ -630,6 +653,7 @@ _register_template(
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="intern2", name="intern2",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -646,6 +670,7 @@ _register_template(
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
) )
_register_template( _register_template(
name="llama2", name="llama2",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@ -662,6 +687,7 @@ _register_template(
), ),
) )
_register_template( _register_template(
name="llama2_zh", name="llama2_zh",
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
@ -669,6 +695,7 @@ _register_template(
default_system="You are a helpful assistant. 你是一个乐于助人的助手。", default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
) )
_register_template( _register_template(
name="llama3", name="llama3",
format_user=StringFormatter( format_user=StringFormatter(
@ -695,6 +722,7 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="mistral", name="mistral",
format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]), format_user=StringFormatter(slots=[" [INST] {{content}} [/INST]"]),
@ -702,6 +730,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="olmo", name="olmo",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
@ -710,6 +739,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="openchat", name="openchat",
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
@ -718,6 +748,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="orion", name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
@ -725,6 +756,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="phi", name="phi",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
@ -736,6 +768,7 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="qwen", name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -747,6 +780,7 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="solar", name="solar",
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
@ -754,6 +788,7 @@ _register_template(
efficient_eos=True, efficient_eos=True,
) )
_register_template( _register_template(
name="starchat", name="starchat",
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
@ -764,6 +799,7 @@ _register_template(
force_system=True, force_system=True,
) )
_register_template( _register_template(
name="vicuna", name="vicuna",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
@ -773,6 +809,7 @@ _register_template(
), ),
) )
_register_template( _register_template(
name="xuanyuan", name="xuanyuan",
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
@ -783,11 +820,13 @@ _register_template(
), ),
) )
_register_template( _register_template(
name="xverse", name="xverse",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]), format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
) )
_register_template( _register_template(
name="yayi", name="yayi",
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
@ -807,6 +846,7 @@ _register_template(
stop_words=["<|End|>"], stop_words=["<|End|>"],
) )
_register_template( _register_template(
name="yi", name="yi",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
@ -815,6 +855,7 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="yuan", name="yuan",
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]), format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
@ -823,6 +864,7 @@ _register_template(
replace_eos=True, replace_eos=True,
) )
_register_template( _register_template(
name="zephyr", name="zephyr",
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
@ -831,12 +873,14 @@ _register_template(
default_system="You are a friendly chatbot who always responds in the style of a pirate", default_system="You are a friendly chatbot who always responds in the style of a pirate",
) )
_register_template( _register_template(
name="ziya", name="ziya",
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]), format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
format_separator=EmptyFormatter(slots=["\n"]), format_separator=EmptyFormatter(slots=["\n"]),
) )
_register_template( _register_template(
name="llava", name="llava",
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT: "]), format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT: "]),

View File

@ -43,7 +43,7 @@ def run_sft(
data_collator = DataCollatorForSeq2Seq( data_collator = DataCollatorForSeq2Seq(
tokenizer=tokenizer, tokenizer=tokenizer,
pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
label_pad_token_id=(IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id), label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
) )
# Override the decoding parameters of Seq2SeqTrainer # Override the decoding parameters of Seq2SeqTrainer

View File

@ -19,6 +19,7 @@ from .sft import run_sft
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import TrainerCallback from transformers import TrainerCallback
logger = get_logger(__name__) logger = get_logger(__name__)