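Fix template encoding: replace str.format "{query}" with a literal "{{query}}" placeholder and append EOS after each history turn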

Former-commit-id: 8172ad1b5e3fa0b224d761ce6069d0db4397da2d
This commit is contained in:
hiyouga 2023-08-04 23:27:55 +08:00
parent ea045b0e5b
commit dbb284b5a2
2 changed files with 19 additions and 17 deletions

View File

@ -55,7 +55,7 @@ def preprocess_dataset(
for query, response, history, prefix in construct_example(examples):
input_ids, labels = [], []
for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix):
for source_ids, target_ids in template.get_dialog(tokenizer, query, response, history, prefix): # TODO: fix bos
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length - 1: # eos token

View File

@ -29,7 +29,7 @@ class Template:
encoded_pairs = self._encode(tokenizer=tokenizer, prefix=prefix, history=history)
prompt_ids = []
for query_ids, resp_ids in encoded_pairs[:-1]:
prompt_ids = prompt_ids + query_ids + resp_ids
prompt_ids = prompt_ids + query_ids + resp_ids + [tokenizer.eos_token_id]
prompt_ids = prompt_ids + encoded_pairs[-1][0]
return prompt_ids, encoded_pairs[-1][1]
@ -96,7 +96,9 @@ class Template:
token_ids = []
for elem in context:
if isinstance(elem, str):
elem = elem.format(query=query)
subelems = elem.split("{{query}}")
if len(subelems) > 1:
elem = subelems[0] + query + subelems[1]
token_ids = token_ids + tokenizer.encode(elem, add_special_tokens=False)
elif isinstance(elem, dict):
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
@ -165,7 +167,7 @@ register_template(
name="vanilla",
prefix=[],
prompt=[
"{query}"
"{{query}}"
],
sep=[],
stop_words=[],
@ -183,7 +185,7 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the user's questions."
],
prompt=[
"Human: {query}\nAssistant: "
"Human: {{query}}\nAssistant: "
],
sep=[
"\n"
@ -211,7 +213,7 @@ register_template(
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
],
prompt=[
"[INST] {query} [/INST] "
"[INST] {{query}} [/INST] "
],
sep=[
{"token": "<s>"}
@ -232,7 +234,7 @@ register_template(
"Write a response that appropriately completes the request."
],
prompt=[
"### Instruction:\n{query}\n\n### Response:\n"
"### Instruction:\n{{query}}\n\n### Response:\n"
],
sep=[
"\n\n"
@ -253,7 +255,7 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the user's questions."
],
prompt=[
"USER: {query} ASSISTANT: "
"USER: {{query}} ASSISTANT: "
],
sep=[],
stop_words=[],
@ -268,7 +270,7 @@ register_template(
name="belle",
prefix=[],
prompt=[
"Human: {query}\n\nBelle: "
"Human: {{query}}\n\nBelle: "
],
sep=[
"\n\n"
@ -285,7 +287,7 @@ register_template(
name="linly",
prefix=[],
prompt=[
"User: {query}\nBot: "
"User: {{query}}\nBot: "
],
sep=[
"\n"
@ -302,7 +304,7 @@ register_template(
name="billa",
prefix=[],
prompt=[
"Human: {query}\nAssistant: "
"Human: {{query}}\nAssistant: "
],
sep=[
"\n"
@ -320,7 +322,7 @@ register_template(
prefix=[],
prompt=[
{"token": "<human>"},
":{query}\n",
":{{query}}\n",
{"token": "<bot>"},
":"
],
@ -342,7 +344,7 @@ register_template(
"The assistant gives helpful, detailed, and polite answers to the human's questions."
],
prompt=[
"Human: {query}###Assistant: "
"Human: {{query}}###Assistant: "
],
sep=[
"###"
@ -360,7 +362,7 @@ register_template(
prefix=[],
prompt=[
{"token": "<|User|>"},
":{query}",
":{{query}}",
{"token": "<eoh>"},
"\n",
{"token": "<|Bot|>"},
@ -385,7 +387,7 @@ register_template(
prefix=[],
prompt=[
{"token": "<reserved_102>"},
"{query}",
"{{query}}",
{"token": "<reserved_103>"}
],
sep=[],
@ -406,7 +408,7 @@ register_template(
],
prompt=[
{"token": "<|user|>"},
"\n{query}",
"\n{{query}}",
{"token": "<|end|>"},
"\n",
{"token": "<|assistant|>"}
@ -433,7 +435,7 @@ register_template(
],
prompt=[
{"token": "<|im_start|>"},
"user\n{query}",
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},