Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 20:00:36 +08:00)

Commit: fix lora target

Diff excerpt from preprocess_dataset:
@@ -74,7 +74,7 @@ def preprocess_dataset(
                 if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break
 
-                if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
+                if turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
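For templates that do not emit an explicit eos after every encoded turn ("efficient_eos"; per the removed comment, originally the Baichuan, Qwen, and GPT-2 templates), the eos that closed the previous assistant turn is kept as the first supervised label of the next source span, while every other prompt token is masked out of the loss. A minimal sketch of this masking rule, not the repository's exact code; build_source_mask is a hypothetical helper, and it assumes IGNORE_INDEX is -100 (the usual value ignored by cross-entropy in Hugging Face training code) plus a toy eos id in place of a real tokenizer:

IGNORE_INDEX = -100  # assumed: label value the loss ignores, per HF convention
EOS_TOKEN_ID = 2     # toy stand-in for tokenizer.eos_token_id

def build_source_mask(source_ids, turn_idx, efficient_eos):
    # Non-first turns of an efficient_eos template must still supervise the
    # eos that terminated the previous turn, so it fills the first label slot.
    if turn_idx != 0 and efficient_eos:
        return [EOS_TOKEN_ID] + [IGNORE_INDEX] * (len(source_ids) - 1)
    return [IGNORE_INDEX] * len(source_ids)  # first turn: mask the whole prompt

print(build_source_mask([101, 102, 103], turn_idx=1, efficient_eos=True))
# -> [2, -100, -100]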
@@ -104,6 +104,9 @@ def preprocess_dataset(
             if len(target_ids) > data_args.max_target_length:
                 target_ids = target_ids[:data_args.max_target_length]
 
+            if template.efficient_eos:
+                target_ids += [tokenizer.eos_token_id]
+
             model_inputs["input_ids"].append(source_ids)
             model_inputs["attention_mask"].append([1] * len(source_ids))
             model_inputs["labels"].append(target_ids)
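This second hunk extends the same idea to the unsupervised/eval path: after the labels are truncated to max_target_length, an eos is appended when the template is efficient_eos, since such a template never embeds the eos in the encoded target itself. Note the order in the diff: truncation happens before the append, so the labels can end up one token longer than max_target_length. A sketch under the same assumptions, with build_labels again a hypothetical helper:

def build_labels(target_ids, max_target_length, eos_token_id, efficient_eos):
    # Truncate first, then restore the terminating eos for efficient_eos
    # templates; the result may exceed max_target_length by one token.
    if len(target_ids) > max_target_length:
        target_ids = target_ids[:max_target_length]
    if efficient_eos:
        target_ids = target_ids + [eos_token_id]
    return target_ids

print(build_labels(list(range(10)), max_target_length=4, eos_token_id=2, efficient_eos=True))
# -> [0, 1, 2, 3, 2]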
@@ -124,6 +127,10 @@ def preprocess_dataset(
             if len(rejected_ids) > data_args.max_target_length:
                 rejected_ids = rejected_ids[:data_args.max_target_length]
 
+            if template.efficient_eos:
+                chosen_ids += [tokenizer.eos_token_id]
+                rejected_ids += [tokenizer.eos_token_id]
+
             model_inputs["prompt_ids"].append(prompt_ids)
             model_inputs["chosen_ids"].append(chosen_ids)
             model_inputs["rejected_ids"].append(rejected_ids)