fix baichuan templates

This commit is contained in:
hiyouga
2023-09-07 18:54:14 +08:00
parent 0531886e1f
commit 85b1f6632a
9 changed files with 53 additions and 87 deletions

View File

@@ -63,7 +63,9 @@ def preprocess_dataset(
for query, response, history, system in construct_example(examples):
input_ids, labels = [], []
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length:
@@ -72,8 +74,17 @@ def preprocess_dataset(
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
break
if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += [IGNORE_INDEX] * len(source_ids) + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))