mirror of https://github.com/hiyouga/LLaMA-Factory.git
commit 100dc4c458 (parent ed584b9f52)
@@ -100,11 +100,13 @@ def get_train_args(
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
-        raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
-
-    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
-        raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+    if finetuning_args.stage in ["rm", "ppo"]:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
+        if training_args.resume_from_checkpoint is not None:
+            raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+        if training_args.load_best_model_at_end:
+            raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
 
     if finetuning_args.stage == "ppo" and not training_args.do_train:
         raise ValueError("PPO training does not support evaluation.")
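The three checks are nested under a single stage test, so the quickest way to see what the tightened guard rejects is to replay it on stand-in objects. A minimal sketch, using SimpleNamespace stand-ins for the project's argument dataclasses (attribute names come from the diff; everything else is assumed):

    # Minimal sketch of the consolidated guard; SimpleNamespace objects stand in
    # for the project's FinetuningArguments / Seq2SeqTrainingArguments.
    from types import SimpleNamespace

    finetuning_args = SimpleNamespace(stage="ppo", finetuning_type="lora")
    training_args = SimpleNamespace(resume_from_checkpoint=None, load_best_model_at_end=True)

    if finetuning_args.stage in ["rm", "ppo"]:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
        if training_args.resume_from_checkpoint is not None:
            raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
        if training_args.load_best_model_at_end:
            raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")  # raised here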
@@ -33,6 +33,12 @@ def run_dpo(
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
 
+    # Create reference model
+    ref_model = None
+    if not isinstance(model, PeftModel):
+        ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+
+    # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
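The new block loads a separate frozen reference model only for full-parameter training. For LoRA the reference stays None, since the frozen base weights already encode the reference policy. A minimal sketch of that fallback, assuming a peft PeftModel (disable_adapter is peft's context manager for temporarily removing the adapter; the helper function and its name are illustrative, not the project's API):

    # Sketch of the ref_model=None fallback assumed here: with the LoRA adapter
    # disabled, a PeftModel computes the frozen base (reference) policy.
    import torch
    from peft import PeftModel

    def reference_logits(model, ref_model, batch):
        with torch.no_grad():
            if ref_model is not None:  # full finetuning: a frozen copy was loaded
                return ref_model(**batch).logits
            assert isinstance(model, PeftModel)
            with model.disable_adapter():  # base weights only == reference policy
                return model(**batch).logits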
@@ -41,7 +47,7 @@ def run_dpo(
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         model=model,
-        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+        ref_model=ref_model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
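Passing the precomputed ref_model also retires the inline deepcopy, presumably a deliberate design choice: a fresh load with is_trainable=False yields a clean frozen reference, whereas deepcopy of the live policy would duplicate it on device and clone whatever training-time state it carries.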
@@ -190,8 +190,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
             if len(response_index) == 0:
                 response_length = 1 # allow empty response
-            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
-                response_length = response_index[-1].item() + 2 # save the EOS token
            else:
                 response_length = response_index[-1].item() + 1
 
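With the pad == eos special case removed, the response length is simply the last non-pad position plus one. A toy check, assuming response_index is the nonzero mask over non-pad positions, as the surrounding variable names suggest (the ids here are made up):

    # Toy check of the simplified length rule; assumes response_index is a
    # non-pad mask, as the variable names in this method suggest.
    import torch

    pad_token_id = 0  # made-up id
    response = torch.tensor([5, 9, 2, 0, 0])  # three real tokens, two pads
    response_index = (response != pad_token_id).nonzero()
    if len(response_index) == 0:
        response_length = 1  # allow empty response
    else:
        response_length = response_index[-1].item() + 1
    print(response_length)  # 3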
@@ -221,7 +219,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
 
         replace_model(unwrapped_model, target="default")
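Instead of taking the last attended position, which lands on the EOS token itself, the reward is now read at the last token that is not EOS, with position 0 as a fallback for degenerate rows. A toy run with made-up ids:

    # Toy run of the new end-index rule.
    import torch

    eos_token_id = 2  # made-up id
    input_ids = torch.tensor([5, 9, 7, 2, 2])  # response, then EOS used as padding
    end_indexes = (input_ids != eos_token_id).nonzero()
    end_index = end_indexes[-1].item() if len(end_indexes) else 0
    print(end_index)  # 2, the position of the last non-EOS token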
@@ -45,9 +45,6 @@ class PairwiseTrainer(Trainer):
         # Split the inputs and rewards into two parts, chosen and rejected
         batch_size = inputs["input_ids"].size(0) // 2
         chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
-        chosen_attn_mask, rejected_attn_mask = (
-            inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
-        )
         chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
         chosen_scores, rejected_scores = [], []
 
@@ -55,8 +52,8 @@ class PairwiseTrainer(Trainer):
         # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
         loss = 0
         for i in range(batch_size):
-            chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
-            rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
+            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
             check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
 
             if len(check_divergence) == 0:
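With the attention-mask split removed two hunks up, sequence lengths now come directly from a pad mask over the ids. This assumes right padding and that the pad id never occurs inside the real sequence; a toy check with made-up ids:

    # Toy check of the pad-mask length rule (assumes right padding).
    import torch

    pad_token_id = 0  # made-up id
    chosen_input_ids = torch.tensor([4, 8, 3, 6, 0, 0])  # four real tokens
    chosen_length = (chosen_input_ids != pad_token_id).nonzero()[-1] + 1
    print(chosen_length.item())  # 4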
@@ -69,7 +66,7 @@ class PairwiseTrainer(Trainer):
             assert div_index > 0
             chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
             rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            if return_outputs: # use the score on the EOS token for inference
+            if return_outputs: # use the score on the last token except pad token for inference
                 chosen_scores.append(chosen_rewards[i, chosen_length-1])
                 rejected_scores.append(rejected_rewards[i, rejected_length-1])
             loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
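For reference, the ranking term in isolation: the loss is the negative log-sigmoid of the chosen-minus-rejected margin, averaged over the positions where the two sequences diverge. With made-up reward slices:

    # The pairwise ranking term in isolation, on made-up reward slices.
    import torch

    chosen_trunc_rewards = torch.tensor([0.8, 1.1, 0.9])
    rejected_trunc_rewards = torch.tensor([0.2, 0.1, 0.3])
    loss = -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
    print(round(loss.item(), 2))  # 0.4; shrinks as the chosen rewards pull ahead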
@@ -95,7 +92,6 @@ class PairwiseTrainer(Trainer):
 
         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
         logger.info(f"Saving prediction results to {output_prediction_file}")
-
         chosen_scores, rejected_scores = predict_results.predictions
 
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
@@ -28,6 +28,7 @@ def run_rm(
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
 
+    # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
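The to_dict round-trip exists because remove_unused_columns must be forced off: Trainer otherwise drops any dataset column that is not an argument of the model's forward pass, which would strip the paired chosen/rejected fields before PairwiseDataCollatorWithPadding sees them. A minimal sketch replaying the pattern (output_dir is a placeholder; note that some transformers versions mask *_token fields in to_dict, so round-tripping a config with hub tokens set needs care):

    # Replaying the arguments round-trip from the diff; "dummy_dir" is a placeholder.
    from transformers import Seq2SeqTrainingArguments

    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(remove_unused_columns=False))  # keep paired columns
    training_args = Seq2SeqTrainingArguments(**training_args_dict)
    print(training_args.remove_unused_columns)  # False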