Former-commit-id: 11c1e1e1570d3712109dd4dce831674a98841bd5
hiyouga 2023-11-07 19:42:01 +08:00
parent ed584b9f52
commit 100dc4c458
5 changed files with 21 additions and 17 deletions

View File

@@ -100,11 +100,13 @@ def get_train_args(
     if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
-        raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
-
-    if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
-        raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+    if finetuning_args.stage in ["rm", "ppo"]:
+        if finetuning_args.finetuning_type != "lora":
+            raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
+        if training_args.resume_from_checkpoint is not None:
+            raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+        if training_args.load_best_model_at_end:
+            raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
 
     if finetuning_args.stage == "ppo" and not training_args.do_train:
         raise ValueError("PPO training does not support evaluation.")

View File

@@ -33,6 +33,12 @@ def run_dpo(
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
 
+    # Create reference model
+    ref_model = None
+    if not isinstance(model, PeftModel):
+        ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
+
+    # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)

@@ -41,7 +47,7 @@ def run_dpo(
     trainer = CustomDPOTrainer(
         beta=finetuning_args.dpo_beta,
         model=model,
-        ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
+        ref_model=ref_model,
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
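The switch from deepcopy(model) to a separately loaded, frozen ref_model avoids copying the trainable policy for full fine-tuning, while LoRA runs can still pass ref_model=None. A rough sketch of why that is enough (an assumption about the trainer's behavior, not the repo's CustomDPOTrainer code):

# Rough sketch, assuming the reference policy of a PeftModel can be recovered by
# temporarily disabling its adapters (the frozen base weights are still present).
import torch
from peft import PeftModel

def reference_forward(model, ref_model, batch):
    if ref_model is not None:
        # Full-parameter fine-tuning: use the separately loaded, frozen copy.
        with torch.no_grad():
            return ref_model(**batch)
    assert isinstance(model, PeftModel)
    # LoRA: disable the adapters to fall back to the frozen base model.
    with torch.no_grad(), model.disable_adapter():
        return model(**batch)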

View File

@@ -190,8 +190,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             if len(response_index) == 0:
                 response_length = 1 # allow empty response
-            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
-                response_length = response_index[-1].item() + 2 # save the EOS token
             else:
                 response_length = response_index[-1].item() + 1

@@ -221,7 +219,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1].item() # use the score on the EOS token
+            end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
 
         replace_model(unwrapped_model, target="default")
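The reward is now read at the last position of input_ids that is not the EOS token, with a fallback to index 0, instead of relying on the attention mask. A toy check with made-up token ids and values:

# Toy check (made-up eos_token_id, ids, and values) of the new reward-index logic:
# take the value at the last non-EOS token, falling back to index 0.
import torch

eos_token_id = 2  # assumption for illustration
input_ids = torch.tensor([
    [5, 6, 7, 2, 2],   # response followed by EOS padding
    [2, 2, 2, 2, 2],   # degenerate all-EOS row
])
values = torch.arange(10, dtype=torch.float32).view(2, 5)

for i in range(values.size(0)):
    end_indexes = (input_ids[i] != eos_token_id).nonzero()
    end_index = end_indexes[-1].item() if len(end_indexes) else 0
    print(i, end_index, values[i, end_index].item())
# prints: 0 2 2.0  /  1 0 5.0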

View File

@@ -34,7 +34,7 @@ class PairwiseTrainer(Trainer):
         Subclass and override to inject custom behavior.
 
         Note that the first element will be removed from the output tuple.
 
         See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
         # Compute rewards

@@ -45,9 +45,6 @@ class PairwiseTrainer(Trainer):
         # Split the inputs and rewards into two parts, chosen and rejected
         batch_size = inputs["input_ids"].size(0) // 2
         chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
-        chosen_attn_mask, rejected_attn_mask = (
-            inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
-        )
         chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
         chosen_scores, rejected_scores = [], []

@@ -55,8 +52,8 @@ class PairwiseTrainer(Trainer):
         # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
         loss = 0
         for i in range(batch_size):
-            chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
-            rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
+            chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
             check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
 
             if len(check_divergence) == 0:

@@ -69,7 +66,7 @@ class PairwiseTrainer(Trainer):
             assert div_index > 0
             chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
             rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            if return_outputs: # use the score on the EOS token for inference
+            if return_outputs: # use the score on the last token except pad token for inference
                 chosen_scores.append(chosen_rewards[i, chosen_length-1])
                 rejected_scores.append(rejected_rewards[i, rejected_length-1])
             loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

@@ -95,7 +92,6 @@ class PairwiseTrainer(Trainer):
         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
         logger.info(f"Saving prediction results to {output_prediction_file}")
         chosen_scores, rejected_scores = predict_results.predictions
-
         with open(output_prediction_file, "w", encoding="utf-8") as writer:
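With the attention-mask split gone, sequence lengths now come from comparing input_ids against pad_token_id, and the inference score is the reward at the last non-pad position. A toy check with made-up ids and rewards:

# Toy check (made-up pad_token_id, ids, and rewards) of the new length computation.
import torch

pad_token_id = 0  # assumption for illustration
chosen_input_ids = torch.tensor([[11, 12, 13, 0, 0]])
chosen_rewards = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]])

i = 0
chosen_length = (chosen_input_ids[i] != pad_token_id).nonzero()[-1] + 1
chosen_score = chosen_rewards[i, chosen_length - 1]
print(chosen_length.item(), round(chosen_score.item(), 2))  # 3 0.3 -> reward at the last real token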

View File

@@ -28,6 +28,7 @@ def run_rm(
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
 
+    # Update arguments
     training_args_dict = training_args.to_dict()
     training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
     training_args = Seq2SeqTrainingArguments(**training_args_dict)
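The added comment labels the same round-trip used in the DPO workflow: the arguments are rebuilt so that remove_unused_columns=False takes effect, otherwise the Trainer would drop the extra dataset columns the pairwise collator needs. A minimal sketch of that round-trip, with output_dir as an assumed placeholder:

# Minimal sketch of the "update arguments" round-trip; output_dir is a placeholder.
# remove_unused_columns must be False so the Trainer keeps the pairwise columns.
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
training_args_dict = training_args.to_dict()
training_args_dict.update(dict(remove_unused_columns=False))
training_args = Seq2SeqTrainingArguments(**training_args_dict)
print(training_args.remove_unused_columns)  # False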