Support safe ChatML template, fix qwen tok #351 #354

https://github.com/openai/openai-python/blob/main/chatml.md
Former-commit-id: f30fc3b0303c9fab17e4563d96f6a33f7189e10d
Authored by hoshi-hiyouga on 2023-08-05 00:00:23 +08:00, committed by GitHub.
commit 3ca0c3be60
3 changed files with 306 additions and 149 deletions
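The gist of the change, as a hedged illustration rather than code from the diff: the old templates embedded ChatML control markers inside ordinary prompt strings, so a tokenizer could re-split them into sub-tokens; the new templates describe them as {"token": ...} entries that are injected as single special-token ids. The two fragments below are illustrative only.

# Illustrative fragments (not full registrations from this commit):
unsafe_piece = "<|im_start|>user\n{query}<|im_end|>\n"            # old style: markers re-tokenized as plain text
safe_pieces = [{"token": "<|im_start|>"}, "user\n{{query}}",      # new style: markers become single token ids
               {"token": "<|im_end|>"}, "\n"]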

View File

@@ -30,10 +30,11 @@ class ChatModel:
     ) -> Tuple[Dict[str, Any], int]:
         prefix = prefix or self.source_prefix

-        prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
-        inputs = self.tokenizer([prompt], return_tensors="pt")
-        inputs = inputs.to(self.model.device)
-        prompt_length = len(inputs["input_ids"][0])
+        prompt, _ = self.template.get_prompt(
+            tokenizer=self.tokenizer, query=query, resp="", history=history, prefix=prefix
+        )
+        input_ids = torch.tensor([prompt], device=self.model.device)
+        prompt_length = len(input_ids[0])

         do_sample = input_kwargs.pop("do_sample", None)
         temperature = input_kwargs.pop("temperature", None)
@@ -45,7 +46,7 @@ class ChatModel:
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
-            input_ids=inputs["input_ids"],
+            input_ids=input_ids,
             do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
             temperature=temperature or gen_kwargs["temperature"],
             top_p=top_p or gen_kwargs["top_p"],
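A hedged sketch of the new inference-side contract: get_prompt() now returns token ids that already cover every history turn plus the final formatted query, so the second return value (the would-be response ids) is discarded and no second tokenizer pass is needed. The chat_model variable and query text below are stand-ins, not names from the diff.

import torch

prompt_ids, _ = chat_model.template.get_prompt(
    tokenizer=chat_model.tokenizer, query="Hello!", resp="", history=[], prefix=None
)
input_ids = torch.tensor([prompt_ids], device=chat_model.model.device)
prompt_length = input_ids.size(1)  # equivalent to len(input_ids[0]) in the diff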

View File

@@ -30,7 +30,7 @@ def preprocess_dataset(
             yield query, response, history, prefix

    def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
+        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
         tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
         total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
@@ -55,20 +55,17 @@ def preprocess_dataset(
         for query, response, history, prefix in construct_example(examples):
             input_ids, labels = [], []

-            for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
-                source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
-                target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
+            for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
-                if len(target_ids) > data_args.max_target_length - 1: # eos token
-                    target_ids = target_ids[:data_args.max_target_length - 1]
+                if len(target_ids) > data_args.max_target_length:
+                    target_ids = target_ids[:data_args.max_target_length]

-                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
+                if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break

-                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
+                input_ids += source_ids + target_ids
+                labels += [IGNORE_INDEX] * len(source_ids) + target_ids

             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -81,10 +78,7 @@ def preprocess_dataset(
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         for query, response, history, prefix in construct_example(examples):
-            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            target_ids = tokenizer.encode(text=response, add_special_tokens=True)
+            source_ids, target_ids = template.get_prompt(tokenizer, query, response, history, prefix)

             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
@@ -101,11 +95,8 @@ def preprocess_dataset(
         # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
         model_inputs = {"accept_ids": [], "reject_ids": []}
         for query, response, history, prefix in construct_example(examples):
-            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
-            reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
+            source_ids, accept_ids = template.get_prompt(tokenizer, query, response[0], history, prefix)
+            source_ids, reject_ids = template.get_prompt(tokenizer, query, response[1], history, prefix)

             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
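To make the supervised fine-tuning change above concrete, here is a toy illustration with hand-made ids rather than a real tokenizer: each (source_ids, target_ids) pair returned by template.get_dialog() is folded into features, with the prompt positions masked out of the loss. The IGNORE_INDEX value of -100 is an assumption (the conventional masking value); the ids are made up.

IGNORE_INDEX = -100  # assumed masking constant; prompt tokens contribute no loss

# one turn as returned by get_dialog(); _encode() has already appended the eos id (2 here)
source_ids, target_ids = [1, 15, 27, 9], [42, 43, 2]

input_ids = source_ids + target_ids
labels = [IGNORE_INDEX] * len(source_ids) + target_ids  # loss only on response tokens
assert len(input_ids) == len(labels) == 7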

View File

@@ -1,82 +1,153 @@
-from typing import Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 from dataclasses import dataclass

+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+

 @dataclass
 class Template:

-    prefix: str
-    prompt: str
-    sep: str
-    use_history: bool
+    prefix: List[Union[str, Dict[str, str]]]
+    prompt: List[Union[str, Dict[str, str]]]
+    sep: List[Union[str, Dict[str, str]]]
     stop_words: List[str]
+    use_history: bool

     def get_prompt(
         self,
+        tokenizer: "PreTrainedTokenizer",
         query: str,
+        resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = "",
-        eos_token: Optional[str] = "</s>"
-    ) -> str:
+        prefix: Optional[str] = None
+    ) -> Tuple[List[int], List[int]]:
         r"""
-        Returns a string containing prompt without response.
+        Returns a single pair of token ids representing prompt and response respectively.
         """
-        return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
+        prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix)
+        encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
+        prompt_ids = []
+        for query_ids, resp_ids in encoded_pairs[:-1]:
+            prompt_ids = prompt_ids + query_ids + resp_ids
+        prompt_ids = prompt_ids + encoded_pairs[-1][0]
+        return prompt_ids, encoded_pairs[-1][1]

     def get_dialog(
         self,
+        tokenizer: "PreTrainedTokenizer",
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
+        prefix: Optional[str] = None
+    ) -> List[Tuple[List[int], List[int]]]:
         r"""
-        Returns a list containing prompt-response pairs.
+        Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        result = self._format_example(query, history, prefix)
-        result[-1][-1] = resp
-        return result
+        prefix, history = self._format(query=query, resp=resp, history=history, prefix=prefix)
+        encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
+        return encoded_pairs

-    def _format_example(
+    def _format(
         self,
         query: str,
+        resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
-        prefix = prefix or self.prefix # use prefix if provided
-        prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
+        prefix: Optional[str] = None
+    ) -> Tuple[List[Union[str, Dict[str, str]]], List[Tuple[str, str]]]:
+        r"""
+        Aligns inputs to a special format.
+        """
+        prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided
+        prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix
         history = history if (history and self.use_history) else []
-        history = history + [(query, "")]
-        return [
-            [(self.sep if i else prefix) + self.prompt.format(query=q), r]
-            for i, (q, r) in enumerate(history)
-        ]
+        history = history + [(query, resp)]
+        return prefix, history
+
+    def _encode(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        prefix: List[Union[str, Dict[str, str]]],
+        history: List[Tuple[str, str]]
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        """
+        if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional
+            bos_token_id = [tokenizer.bos_token_id]
+        else:
+            bos_token_id = []
+        eos_token_id = [tokenizer.eos_token_id] # eos token is required
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0:
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix)
+            else:
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
+        return encoded_pairs
+
+    def _convert_inputs_to_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        context: List[Union[str, Dict[str, str]]],
+        query: Optional[str] = ""
+    ) -> List[int]:
+        r"""
+        Converts context to token ids.
+        """
+        token_ids = []
+        for elem in context:
+            if isinstance(elem, str):
+                subelems = elem.split("{{query}}")
+                if len(subelems) > 1:
+                    elem = subelems[0] + query + subelems[1]
+                token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
+            elif isinstance(elem, dict):
+                token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+            else:
+                raise NotImplementedError
+        return token_ids


 @dataclass
 class Llama2Template(Template):

-    def _format_example(
+    def _encode(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
-        prefix = prefix or self.prefix # use prefix if provided
-        prefix = prefix if prefix.startswith("<<SYS>>") else "<<SYS>>\n{}\n<</SYS>>\n\n".format(prefix)
-        history = history if (history and self.use_history) else []
-        history = history + [(query, "")]
-        return [
-            [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r]
-            for i, (q, r) in enumerate(history)
-        ]
+        tokenizer: "PreTrainedTokenizer",
+        prefix: List[Union[str, Dict[str, str]]],
+        history: List[Tuple[str, str]]
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        """
+        encoded_pairs = []
+        assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0:
+                prefix_ids = []
+                query = prefix[0] + query
+            else:
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((prefix_ids + query_ids, resp_ids))
+        return encoded_pairs


 templates: Dict[str, Template] = {}


 def register_template(
-    name: str, prefix: str, prompt: str, sep: str, use_history: bool, stop_words: List[str]
+    name: str,
+    prefix: List[Union[str, Dict[str, str]]],
+    prompt: List[Union[str, Dict[str, str]]],
+    sep: List[Union[str, Dict[str, str]]],
+    stop_words: List[str],
+    use_history: bool
 ) -> None:
     template_class = Llama2Template if name == "llama2" else Template
     templates[name] = template_class(
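A hedged usage sketch of the new token-id API. It assumes the Template dataclass shown above is importable in the current scope; the "gpt2" checkpoint is purely illustrative (any tokenizer with an eos token behaves the same here), and the template strings are made up rather than taken from the registrations below.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint

template = Template(
    prefix=["You are a helpful assistant."],
    prompt=["Human: {{query}}\nAssistant: "],
    sep=["\n"],
    stop_words=[],
    use_history=True
)

# prompt ids and response ids come back already tokenized
prompt_ids, answer_ids = template.get_prompt(
    tokenizer=tokenizer, query="What does this commit change?", resp="It makes templates token-safe."
)
print(tokenizer.decode(prompt_ids))   # prefix, separator, then the formatted query
print(tokenizer.decode(answer_ids))   # the response followed by the eos token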
@@ -99,11 +170,13 @@ Supports language model inference without histories.
 """
 register_template(
     name="vanilla",
-    prefix="",
-    prompt="{query}",
-    sep="",
-    use_history=False,
-    stop_words=[]
+    prefix=[],
+    prompt=[
+        "{{query}}"
+    ],
+    sep=[],
+    stop_words=[],
+    use_history=False
 )
@@ -112,12 +185,18 @@ Default template.
 """
 register_template(
     name="default",
-    prefix="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    prompt="Human: {query}\nAssistant: ",
-    sep="\n",
-    use_history=True,
-    stop_words=[]
+    prefix=[
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ],
+    prompt=[
+        "Human: {{query}}\nAssistant: "
+    ],
+    sep=[
+        "\n"
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -128,18 +207,24 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
 """
 register_template(
     name="llama2",
-    prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
-           "Always answer as helpfully as possible, while being safe. "
-           "Your answers should not include any harmful, unethical, "
-           "racist, sexist, toxic, dangerous, or illegal content. "
-           "Please ensure that your responses are socially unbiased and positive in nature.\n"
-           "If a question does not make any sense, or is not factually coherent, "
-           "explain why instead of answering something not correct. "
-           "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
-    prompt="[INST] {query} [/INST] ",
-    sep="<s>",
-    use_history=True,
-    stop_words=[]
+    prefix=[
+        "<<SYS>>\nYou are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe. "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    sep=[
+        {"token": "<s>"}
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -149,12 +234,18 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
 """
 register_template(
     name="alpaca",
-    prefix="Below is an instruction that describes a task. "
-           "Write a response that appropriately completes the request.",
-    prompt="### Instruction:\n{query}\n\n### Response:\n",
-    sep="\n\n",
-    use_history=True,
-    stop_words=[]
+    prefix=[
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    ],
+    prompt=[
+        "### Instruction:\n{{query}}\n\n### Response:\n"
+    ],
+    sep=[
+        "\n\n"
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -164,12 +255,16 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
 """
 register_template(
     name="vicuna",
-    prefix="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    prompt="USER: {query} ASSISTANT: ",
-    sep="",
-    use_history=True,
-    stop_words=[]
+    prefix=[
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ],
+    prompt=[
+        "USER: {{query}} ASSISTANT: "
+    ],
+    sep=[],
+    stop_words=[],
+    use_history=True
 )
@@ -178,11 +273,15 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
 """
 register_template(
     name="belle",
-    prefix="",
-    prompt="Human: {query}\n\nBelle: ",
-    sep="\n\n",
-    use_history=True,
-    stop_words=[]
+    prefix=[],
+    prompt=[
+        "Human: {{query}}\n\nBelle: "
+    ],
+    sep=[
+        "\n\n"
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -191,11 +290,15 @@ Supports: https://github.com/CVI-SZU/Linly
 """
 register_template(
     name="linly",
-    prefix="",
-    prompt="User: {query}\nBot: ",
-    sep="\n",
-    use_history=True,
-    stop_words=[]
+    prefix=[],
+    prompt=[
+        "User: {{query}}\nBot: "
+    ],
+    sep=[
+        "\n"
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -204,11 +307,15 @@ Supports: https://github.com/Neutralzz/BiLLa
 """
 register_template(
     name="billa",
-    prefix="",
-    prompt="Human: {query}\nAssistant: ",
-    sep="\n",
-    use_history=True,
-    stop_words=[]
+    prefix=[],
+    prompt=[
+        "Human: {{query}}\nAssistant: "
+    ],
+    sep=[
+        "\n"
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -217,11 +324,18 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
 """
 register_template(
     name="ziya",
-    prefix="",
-    prompt="<human>:{query}\n<bot>:",
-    sep="\n",
-    use_history=True,
-    stop_words=[]
+    prefix=[],
+    prompt=[
+        {"token": "<human>"},
+        ":{{query}}\n",
+        {"token": "<bot>"},
+        ":"
+    ],
+    sep=[
+        "\n"
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -230,12 +344,18 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
 """
 register_template(
     name="aquila",
-    prefix="A chat between a curious human and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
-    prompt="Human: {query}###Assistant: ",
-    sep="###",
-    use_history=True,
-    stop_words=[]
+    prefix=[
+        "A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+    ],
+    prompt=[
+        "Human: {{query}}###Assistant: "
+    ],
+    sep=[
+        "###"
+    ],
+    stop_words=[],
+    use_history=True
 )
@@ -244,11 +364,23 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
 """
 register_template(
     name="intern",
-    prefix="",
-    prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
-    sep="<eoa>\n",
-    use_history=True,
-    stop_words=["<eoa>"]
+    prefix=[],
+    prompt=[
+        {"token": "<|User|>"},
+        ":{{query}}",
+        {"token": "<eoh>"},
+        "\n",
+        {"token": "<|Bot|>"},
+        ":"
+    ],
+    sep=[
+        {"token": "<eoa>"},
+        "\n"
+    ],
+    stop_words=[
+        "<eoa>"
+    ],
+    use_history=True
 )
@@ -257,11 +389,15 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 """
 register_template(
     name="baichuan",
-    prefix="",
-    prompt="<reserved_102>{query}<reserved_103>",
-    sep="",
-    use_history=True,
-    stop_words=[]
+    prefix=[],
+    prompt=[
+        {"token": "<reserved_102>"},
+        "{{query}}",
+        {"token": "<reserved_103>"}
+    ],
+    sep=[],
+    stop_words=[],
+    use_history=True
 )
@@ -271,11 +407,25 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
 """
 register_template(
     name="starchat",
-    prefix="<|system|>\n",
-    prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
-    sep="<|end|>\n",
-    use_history=True,
-    stop_words=["<|end|>"]
+    prefix=[
+        {"token": "<|system|>"},
+        "\n"
+    ],
+    prompt=[
+        {"token": "<|user|>"},
+        "\n{{query}}",
+        {"token": "<|end|>"},
+        "\n",
+        {"token": "<|assistant|>"}
+    ],
+    sep=[
+        {"token": "<|end|>"},
+        "\n"
+    ],
+    stop_words=[
+        "<|end|>"
+    ],
+    use_history=True
 )
@@ -284,9 +434,24 @@ Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
 """
 register_template(
     name="chatml",
-    prefix="<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
-    prompt="<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n",
-    sep="<|im_end|>\n",
-    use_history=True,
-    stop_words=["<|im_end|>"]
+    prefix=[
+        {"token": "<|im_start|>"},
+        "system\nYou are a helpful assistant."
+    ],
+    prompt=[
+        {"token": "<|im_start|>"},
+        "user\n{{query}}",
+        {"token": "<|im_end|>"},
+        "\n",
+        {"token": "<|im_start|>"},
+        "assistant\n"
+    ],
+    sep=[
+        {"token": "<|im_end|>"},
+        "\n"
+    ],
+    stop_words=[
+        "<|im_end|>"
+    ],
+    use_history=True
 )
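For a sanity check of the chatml registration above, here is a minimal sketch that flattens its prefix, sep, and prompt pieces into display text. In the real _encode() path the {"token": ...} entries become single token ids via convert_tokens_to_ids; here they are spliced in as text only so the resulting layout is easy to inspect. The render() helper and the example query are illustrative, not part of the commit.

def render(pieces, query=""):
    # illustrative helper: dict entries contribute their token text, strings fill in {{query}}
    out = []
    for elem in pieces:
        text = elem["token"] if isinstance(elem, dict) else elem.replace("{{query}}", query)
        out.append(text)
    return "".join(out)

prefix = [{"token": "<|im_start|>"}, "system\nYou are a helpful assistant."]
sep = [{"token": "<|im_end|>"}, "\n"]
prompt = [{"token": "<|im_start|>"}, "user\n{{query}}", {"token": "<|im_end|>"}, "\n",
          {"token": "<|im_start|>"}, "assistant\n"]

print(render(prefix + sep + prompt, query="How are you?"))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# How are you?<|im_end|>
# <|im_start|>assistant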