fix dpo metrics

Former-commit-id: 4270f7dfb9a12471c91f6c03dce7ca6fd88566e1
Author: hiyouga
Date: 2024-11-02 19:22:11 +08:00
Parent: e7b11e4fdb
Commit: c2766af6f4
7 changed files with 143 additions and 58 deletions

View File

@@ -287,13 +287,13 @@ class LogCallback(TrainerCallback):
         logs = dict(
             current_steps=self.cur_steps,
             total_steps=self.max_steps,
-            loss=state.log_history[-1].get("loss", None),
-            eval_loss=state.log_history[-1].get("eval_loss", None),
-            predict_loss=state.log_history[-1].get("predict_loss", None),
-            reward=state.log_history[-1].get("reward", None),
-            accuracy=state.log_history[-1].get("rewards/accuracies", None),
-            learning_rate=state.log_history[-1].get("learning_rate", None),
-            epoch=state.log_history[-1].get("epoch", None),
+            loss=state.log_history[-1].get("loss"),
+            eval_loss=state.log_history[-1].get("eval_loss"),
+            predict_loss=state.log_history[-1].get("predict_loss"),
+            reward=state.log_history[-1].get("reward"),
+            accuracy=state.log_history[-1].get("rewards/accuracies"),
+            lr=state.log_history[-1].get("learning_rate"),
+            epoch=state.log_history[-1].get("epoch"),
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
@@ -304,16 +304,17 @@ class LogCallback(TrainerCallback):
         if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
             vram_allocated, vram_reserved = get_peak_memory()
-            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
-            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
+            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
+            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)

         logs = {k: v for k, v in logs.items() if v is not None}
-        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
-            logger.info_rank0(
-                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
-                )
-            )
+        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
+            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
+            for extra_key in ("reward", "accuracy", "throughput"):
+                if logs.get(extra_key):
+                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
+
+            logger.info_rank0("{" + log_str + "}")

         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
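The change above replaces the fixed four-field log string with one that is built incrementally, so DPO-style fields such as reward and accuracy appear in the web UI only when they exist, and peak VRAM is reported in GiB. A minimal, self-contained sketch of that pattern follows; format_webui_log and peak_vram_gib are hypothetical helpers written for illustration, and torch.cuda.max_memory_allocated/max_memory_reserved stand in for the repository's get_peak_memory().

import torch


def format_webui_log(logs: dict) -> str:
    # required fields are always formatted; optional metrics are appended only when present
    log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
    for extra_key in ("reward", "accuracy", "throughput"):
        if logs.get(extra_key):
            log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

    return "{" + log_str + "}"


def peak_vram_gib() -> tuple:
    # assumption: CUDA peak statistics approximate the repo's get_peak_memory() helper
    if not torch.cuda.is_available():
        return 0.0, 0.0

    allocated = torch.cuda.max_memory_allocated()  # bytes
    reserved = torch.cuda.max_memory_reserved()  # bytes
    return round(allocated / (1024**3), 2), round(reserved / (1024**3), 2)


print(format_webui_log({"loss": 0.4213, "lr": 5e-5, "epoch": 1.25, "accuracy": 0.87}))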

View File

@@ -250,20 +250,18 @@ class CustomDPOTrainer(DPOTrainer):
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss

-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
+        metrics[f"{prefix}logps/rejected"] = policy_chosen_logps.mean().item()
+        metrics[f"{prefix}logps/chosen"] = policy_rejected_logps.mean().item()
+        metrics[f"{prefix}logits/rejected"] = policy_chosen_logits.mean().item()
+        metrics[f"{prefix}logits/chosen"] = policy_rejected_logits.mean().item()
         if self.loss_type == "orpo":
-            metrics[f"{prefix}sft_loss"] = sft_loss.detach().mean().cpu()
-            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
+            metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
+            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()

         return losses.mean(), metrics

@@ -275,6 +273,36 @@ class CustomDPOTrainer(DPOTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs)
         if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps

         return loss

+    @override
+    def log(self, logs: Dict[str, float]) -> None:
+        r"""
+        Log `logs` on the various objects watching training, including stored metrics.
+        """
+        # logs either has "loss" or "eval_loss"
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        key_list, metric_list = [], []
+        for key, metrics in self._stored_metrics[train_eval].items():
+            key_list.append(key)
+            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())
+
+        del self._stored_metrics[train_eval]
+        if len(metric_list) < 10:  # pad to for all reduce
+            for i in range(10 - len(metric_list)):
+                key_list.append(f"dummy_{i}")
+                metric_list.append(0.0)
+
+        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+        metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
+        for key, metric in zip(key_list, metric_list):  # add remaining items
+            if not key.startswith("dummy_"):
+                logs[key] = metric
+
+        return Trainer.log(self, logs)
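The new log override averages the stored DPO metrics across processes instead of logging local values only. Because accelerator.reduce is a collective call, every rank must contribute a tensor of the same shape, which is why the metric vector is padded with dummy slots to a fixed length before the reduction and the dummies are dropped afterwards. A rough sketch of that idea, assuming an accelerate.Accelerator; reduce_stored_metrics is a hypothetical helper, not the project's API:

import torch
from accelerate import Accelerator

PAD_TO = 10  # fixed vector length shared by all ranks


def reduce_stored_metrics(accelerator: Accelerator, stored: dict) -> dict:
    key_list, metric_list = [], []
    for key, values in stored.items():  # per-key lists recorded since the last log step
        key_list.append(key)
        metric_list.append(torch.tensor(values, dtype=torch.float).mean().item())

    while len(metric_list) < PAD_TO:  # pad so every rank reduces a tensor of the same shape
        key_list.append(f"dummy_{len(metric_list)}")
        metric_list.append(0.0)

    reduced = accelerator.reduce(
        torch.tensor(metric_list, dtype=torch.float, device=accelerator.device), reduction="mean"
    ).tolist()
    return {key: value for key, value in zip(key_list, reduced) if not key.startswith("dummy_")}


if __name__ == "__main__":
    accelerator = Accelerator()
    print(reduce_stored_metrics(accelerator, {"rewards/chosen": [0.1, 0.3], "rewards/rejected": [-0.2, 0.0]}))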

View File

@@ -131,7 +131,7 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Runs forward pass and computes the log probabilities.
         """
@@ -151,23 +151,25 @@ class CustomKTOTrainer(KTOTrainer):
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
         logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
-        return logps, logps / valid_length
+        return logits, logps, logps / valid_length

     @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        target_logps, target_logps_avg = self.forward(model, batch)
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
-            kl_logps, _ = self.forward(model, batch, prefix="kl_")
+            _, kl_logps, _ = self.forward(model, batch, prefix="kl_")

         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")

+        chosen_logits = target_logits[batch["kto_tags"]]
         chosen_logps = target_logps[batch["kto_tags"]]
+        rejected_logits = target_logits[~batch["kto_tags"]]
         rejected_logps = target_logps[~batch["kto_tags"]]
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
-        return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg

     @override
     def compute_reference_log_probs(
@@ -184,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+            reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
                 ref_model, batch
             )
@@ -200,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """
         metrics = {}
-        policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
-            self.concatenated_forward(model, batch)
-        )
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_kl_logps,
+            policy_chosen_logps_avg,
+        ) = self.concatenated_forward(model, batch)

         reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
             model, batch
         )
@@ -220,24 +227,21 @@ class CustomKTOTrainer(KTOTrainer):
             sft_loss = -policy_chosen_logps_avg
             losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])

-        num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
-        num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
-
-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
-
-        if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["count/chosen"] = all_num_chosen
-
-        if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["count/rejected"] = all_num_rejected
+        num_chosen = len(chosen_rewards)
+        num_rejected = len(rejected_rewards)
+        if num_chosen > 0:
+            metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
+            metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
+            metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
+            metrics["count/chosen"] = float(num_chosen)
+
+        if num_rejected > 0:
+            metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
+            metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
+            metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
+            metrics["count/rejected"] = float(num_rejected)

         metrics["kl"] = kl.item()

         return losses, metrics

     @override
@@ -248,6 +252,48 @@ class CustomKTOTrainer(KTOTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs)
         if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps

         return loss

+    @override
+    def log(self, logs: Dict[str, float]) -> None:
+        r"""
+        Log `logs` on the various objects watching training, including stored metrics.
+        """
+        # logs either has "loss" or "eval_loss"
+        train_eval = "train" if "loss" in logs else "eval"
+        prefix = "eval_" if train_eval == "eval" else ""
+        # Add averaged stored metrics to logs
+        key_list, metric_list = [], []
+        for key, metrics in self._stored_metrics[train_eval].items():
+            key_list.append(key)
+            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())

+        del self._stored_metrics[train_eval]
+        if len(metric_list) < 9:  # pad to for all reduce
+            for i in range(9 - len(metric_list)):
+                key_list.append(f"dummy_{i}")
+                metric_list.append(0.0)
+
+        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+        metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
+        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        for split in ["chosen", "rejected"]:  # accumulate average metrics from sums and lengths
+            if f"count/{split}" in metric_dict:
+                for key in ("rewards", "logps", "logits"):
+                    logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
+                    del metric_dict[f"{key}/{split}_sum"]
+
+                del metric_dict[f"count/{split}"]
+
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:  # calculate reward margin
+            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+
+        for key, metric in metric_dict.items():  # add remaining items
+            if not key.startswith("dummy_"):
+                logs[key] = metric
+
+        return Trainer.log(self, logs)
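For KTO the trainer now stores per-batch sums and counts rather than means, reduces them with "sum" across processes, and only then divides, so the chosen/rejected averages stay exact even when ranks see different numbers of chosen and rejected examples. The sketch below shows the sum-and-count bookkeeping performed after the reduction, including the derived reward margin; summarize_kto_metrics is illustrative only, not a project function:

from typing import Dict


def summarize_kto_metrics(reduced: Dict[str, float], prefix: str = "") -> Dict[str, float]:
    # `reduced` holds cross-rank sums such as "rewards/chosen_sum" and counts such as "count/chosen"
    logs: Dict[str, float] = {}
    for split in ("chosen", "rejected"):
        count = reduced.get(f"count/{split}", 0.0)
        if count > 0:
            for key in ("rewards", "logps", "logits"):
                logs[f"{prefix}{key}/{split}"] = reduced[f"{key}/{split}_sum"] / count

    if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
        logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]

    return logs


print(
    summarize_kto_metrics(
        {
            "rewards/chosen_sum": 12.0, "logps/chosen_sum": -300.0, "logits/chosen_sum": 5.0, "count/chosen": 6.0,
            "rewards/rejected_sum": -4.0, "logps/rejected_sum": -260.0, "logits/rejected_sum": 3.0, "count/rejected": 4.0,
        }
    )
)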

View File

@@ -81,7 +81,7 @@ def run_kto(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])

     # Evaluation
     if training_args.do_eval:

View File

@@ -74,6 +74,10 @@ class CustomTrainer(Trainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
         if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+            # other model should not scale the loss
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps

         return loss
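All four compute_loss overrides touched by this commit apply the same transformers 4.46 workaround: the gradient-accumulation division must hit only the loss tensor, so when return_outputs=True and the parent class returns a (loss, outputs) tuple, only the first element is scaled. A hedged sketch of that branch with the base call stubbed out; scale_for_grad_accum is illustrative, not part of the repo:

from typing import Any, Tuple, Union

import torch


def scale_for_grad_accum(loss: Union[torch.Tensor, Tuple[Any, ...]], grad_accum_steps: int):
    # return_outputs=True yields a (loss, *outputs) tuple: scale only the loss element
    if isinstance(loss, tuple):
        return (loss[0] / grad_accum_steps, *loss[1:])

    return loss / grad_accum_steps  # plain tensor loss


print(scale_for_grad_accum(torch.tensor(2.0), 4))
print(scale_for_grad_accum((torch.tensor(2.0), {"logits": None}), 4))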

View File

@@ -87,7 +87,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
         if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+            # other model should not scale the loss
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps

         return loss

View File

@@ -98,6 +98,7 @@ class Runner:
     def _finalize(self, lang: str, finish_info: str) -> str:
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
+        gr.Info(finish_info)
         self.trainer = None
         self.aborted = False
         self.running = False
@@ -357,6 +358,7 @@ class Runner:
         progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
         loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
+        running_log = ""
         while self.trainer is not None:
             if self.aborted:
                 yield {
@@ -392,7 +394,7 @@ class Runner:
             finish_info = ALERTS["err_failed"][lang]

         return_dict = {
-            output_box: self._finalize(lang, finish_info),
+            output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
             progress_bar: gr.Slider(visible=False),
         }
         yield return_dict
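The runner change initializes running_log before the polling loop (it is presumably updated inside the loop in code not shown in this hunk), raises a gr.Info toast in _finalize, and appends the accumulated log to the final message so the finish screen still shows the training output. A simplified, framework-free sketch of that polling pattern; monitor and its callables are illustrative names only:

import time
from typing import Callable, Iterator


def monitor(poll_status: Callable[[], str], read_log: Callable[[], str], finalize: Callable[[], str]) -> Iterator[str]:
    running_log = ""  # initialized before the loop, as in the diff
    while poll_status() == "running":
        running_log = read_log()  # remember the latest output while training runs
        yield running_log
        time.sleep(0.1)  # short interval for the demo; the web UI polls every couple of seconds

    yield finalize() + "\n\n" + running_log  # finish message keeps the last log text


status = iter(["running", "running", "done"])
logs = iter(["step 10", "step 20", "step 20"])
for update in monitor(lambda: next(status), lambda: next(logs), lambda: "Training finished."):
    print(update)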