fix encode

Former-commit-id: 8172ad1b5e3fa0b224d761ce6069d0db4397da2d
This commit is contained in:
hiyouga 2023-08-04 23:27:55 +08:00
parent ea045b0e5b
commit dbb284b5a2
2 changed files with 19 additions and 17 deletions

View File

@ -55,7 +55,7 @@ def preprocess_dataset(
for query, response, history, prefix in construct_example(examples): for query, response, history, prefix in construct_example(examples):
input_ids, labels = [], [] input_ids, labels = [], []
for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): # TODO: fix bos
if len(source_ids) > data_args.max_source_length: if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length] source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length - 1: # eos token if len(target_ids) > data_args.max_target_length - 1: # eos token

View File

@ -29,7 +29,7 @@ class Template:
encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history) encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
prompt_ids = [] prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]: for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids prompt_ids = prompt_ids + query_ids + resp_ids + [tokenizer.eos_token_id]
prompt_ids = prompt_ids + encoded_pairs[-1][0] prompt_ids = prompt_ids + encoded_pairs[-1][0]
return prompt_ids, encoded_pairs[-1][1] return prompt_ids, encoded_pairs[-1][1]
@ -96,7 +96,9 @@ class Template:
token_ids = [] token_ids = []
for elem in context: for elem in context:
if isinstance(elem, str): if isinstance(elem, str):
elem = elem.format(query=query) subelems = elem.split("{{query}}")
if len(subelems) > 1:
elem = subelems[0] + query + subelems[1]
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False) token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict): elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
@ -165,7 +167,7 @@ register_template(
name="vanilla", name="vanilla",
prefix=[], prefix=[],
prompt=[ prompt=[
"{query}" "{{query}}"
], ],
sep=[], sep=[],
stop_words=[], stop_words=[],
@ -183,7 +185,7 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the user's questions." "The assistant gives helpful, detailed, and polite answers to the user's questions."
], ],
prompt=[ prompt=[
"Human: {query}\nAssistant: " "Human: {{query}}\nAssistant: "
], ],
sep=[ sep=[
"\n" "\n"
@ -211,7 +213,7 @@ register_template(
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n" "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
], ],
prompt=[ prompt=[
"[INST] {query} [/INST] " "[INST] {{query}} [/INST] "
], ],
sep=[ sep=[
{"token": "<s>"} {"token": "<s>"}
@ -232,7 +234,7 @@ register_template(
"Write a response that appropriately completes the request." "Write a response that appropriately completes the request."
], ],
prompt=[ prompt=[
"### Instruction:\n{query}\n\n### Response:\n" "### Instruction:\n{{query}}\n\n### Response:\n"
], ],
sep=[ sep=[
"\n\n" "\n\n"
@ -253,7 +255,7 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the user's questions." "The assistant gives helpful, detailed, and polite answers to the user's questions."
], ],
prompt=[ prompt=[
"USER: {query} ASSISTANT: " "USER: {{query}} ASSISTANT: "
], ],
sep=[], sep=[],
stop_words=[], stop_words=[],
@ -268,7 +270,7 @@ register_template(
name="belle", name="belle",
prefix=[], prefix=[],
prompt=[ prompt=[
"Human: {query}\n\nBelle: " "Human: {{query}}\n\nBelle: "
], ],
sep=[ sep=[
"\n\n" "\n\n"
@ -285,7 +287,7 @@ register_template(
name="linly", name="linly",
prefix=[], prefix=[],
prompt=[ prompt=[
"User: {query}\nBot: " "User: {{query}}\nBot: "
], ],
sep=[ sep=[
"\n" "\n"
@ -302,7 +304,7 @@ register_template(
name="billa", name="billa",
prefix=[], prefix=[],
prompt=[ prompt=[
"Human: {query}\nAssistant: " "Human: {{query}}\nAssistant: "
], ],
sep=[ sep=[
"\n" "\n"
@ -320,7 +322,7 @@ register_template(
prefix=[], prefix=[],
prompt=[ prompt=[
{"token": "<human>"}, {"token": "<human>"},
":{query}\n", ":{{query}}\n",
{"token": "<bot>"}, {"token": "<bot>"},
":" ":"
], ],
@ -342,7 +344,7 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the human's questions." "The assistant gives helpful, detailed, and polite answers to the human's questions."
], ],
prompt=[ prompt=[
"Human: {query}###Assistant: " "Human: {{query}}###Assistant: "
], ],
sep=[ sep=[
"###" "###"
@ -360,7 +362,7 @@ register_template(
prefix=[], prefix=[],
prompt=[ prompt=[
{"token": "<|User|>"}, {"token": "<|User|>"},
":{query}", ":{{query}}",
{"token": "<eoh>"}, {"token": "<eoh>"},
"\n", "\n",
{"token": "<|Bot|>"}, {"token": "<|Bot|>"},
@ -385,7 +387,7 @@ register_template(
prefix=[], prefix=[],
prompt=[ prompt=[
{"token": "<reserved_102>"}, {"token": "<reserved_102>"},
"{query}", "{{query}}",
{"token": "<reserved_103>"} {"token": "<reserved_103>"}
], ],
sep=[], sep=[],
@ -406,7 +408,7 @@ register_template(
], ],
prompt=[ prompt=[
{"token": "<|user|>"}, {"token": "<|user|>"},
"\n{query}", "\n{{query}}",
{"token": "<|end|>"}, {"token": "<|end|>"},
"\n", "\n",
{"token": "<|assistant|>"} {"token": "<|assistant|>"}
@ -433,7 +435,7 @@ register_template(
], ],
prompt=[ prompt=[
{"token": "<|im_start|>"}, {"token": "<|im_start|>"},
"user\n{query}", "user\n{{query}}",
{"token": "<|im_end|>"}, {"token": "<|im_end|>"},
"\n", "\n",
{"token": "<|im_start|>"}, {"token": "<|im_start|>"},