fix lora target

This commit is contained in:
hiyouga
2023-09-09 17:04:45 +08:00
parent bca1a247bc
commit a51b7c98ac
7 changed files with 63 additions and 43 deletions

View File

@@ -74,7 +74,7 @@ def preprocess_dataset(
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
break
if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
if turn_idx != 0 and template.efficient_eos:
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
@@ -104,6 +104,9 @@ def preprocess_dataset(
if len(target_ids) > data_args.max_target_length:
target_ids = target_ids[:data_args.max_target_length]
if template.efficient_eos:
target_ids += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(source_ids)
model_inputs["attention_mask"].append([1] * len(source_ids))
model_inputs["labels"].append(target_ids)
@@ -124,6 +127,10 @@ def preprocess_dataset(
if len(rejected_ids) > data_args.max_target_length:
rejected_ids = rejected_ids[:data_args.max_target_length]
if template.efficient_eos:
chosen_ids += [tokenizer.eos_token_id]
rejected_ids += [tokenizer.eos_token_id]
model_inputs["prompt_ids"].append(prompt_ids)
model_inputs["chosen_ids"].append(chosen_ids)
model_inputs["rejected_ids"].append(rejected_ids)