Merge pull request #5913 from hiyouga/hiyouga/dev_metrics

[train] support gather DPO metrics, fix return output

Former-commit-id: 344ff76d26a42c859f31cd03765b1b613ffe6bfa
hoshi-hiyouga committed 2024-11-02 21:13:43 +08:00 (committed by GitHub)
commit 7dbb338df7
9 changed files with 146 additions and 62 deletions

View File

@@ -115,7 +115,6 @@ def _process_request(
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
-    images = None if len(images) == 0 else images
     tool_list = request.tools
     if isinstance(tool_list, list) and len(tool_list):
         try:
@@ -125,7 +124,7 @@ def _process_request(
     else:
         tools = None
 
-    return input_messages, system, tools, images
+    return input_messages, system, tools, images or None
 
 
 def _create_stream_chat_completion_chunk(

View File

@@ -287,13 +287,13 @@ class LogCallback(TrainerCallback):
         logs = dict(
             current_steps=self.cur_steps,
             total_steps=self.max_steps,
-            loss=state.log_history[-1].get("loss", None),
-            eval_loss=state.log_history[-1].get("eval_loss", None),
-            predict_loss=state.log_history[-1].get("predict_loss", None),
-            reward=state.log_history[-1].get("reward", None),
-            accuracy=state.log_history[-1].get("rewards/accuracies", None),
-            learning_rate=state.log_history[-1].get("learning_rate", None),
-            epoch=state.log_history[-1].get("epoch", None),
+            loss=state.log_history[-1].get("loss"),
+            eval_loss=state.log_history[-1].get("eval_loss"),
+            predict_loss=state.log_history[-1].get("predict_loss"),
+            reward=state.log_history[-1].get("reward"),
+            accuracy=state.log_history[-1].get("rewards/accuracies"),
+            lr=state.log_history[-1].get("learning_rate"),
+            epoch=state.log_history[-1].get("epoch"),
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
@@ -304,16 +304,17 @@ class LogCallback(TrainerCallback):
         if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
             vram_allocated, vram_reserved = get_peak_memory()
-            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
-            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
+            logs["vram_allocated"] = round(vram_allocated / (1024**3), 2)
+            logs["vram_reserved"] = round(vram_reserved / (1024**3), 2)
 
         logs = {k: v for k, v in logs.items() if v is not None}
-        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
-            logger.info_rank0(
-                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
-                )
-            )
+        if self.webui_mode and all(key in logs for key in ("loss", "lr", "epoch")):
+            log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
+            for extra_key in ("reward", "accuracy", "throughput"):
+                if logs.get(extra_key):
+                    log_str += f", '{extra_key}': {logs[extra_key]:.2f}"
+
+            logger.info_rank0("{" + log_str + "}")
 
         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
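
As context for the hunk above: the callback now builds the Web UI log line incrementally, appending optional keys only when they are present and truthy. A small standalone illustration of that pattern (the values below are made up):

# Standalone illustration of the new Web UI log-line construction; the `logs`
# dict here uses dummy values and is not taken from a real training run.
logs = {"loss": 0.4213, "lr": 5e-5, "epoch": 1.25, "accuracy": 0.8125}

log_str = f"'loss': {logs['loss']:.4f}, 'learning_rate': {logs['lr']:2.4e}, 'epoch': {logs['epoch']:.2f}"
for extra_key in ("reward", "accuracy", "throughput"):
    if logs.get(extra_key):  # optional metrics are appended only when logged
        log_str += f", '{extra_key}': {logs[extra_key]:.2f}"

print("{" + log_str + "}")
# {'loss': 0.4213, 'learning_rate': 5.0000e-05, 'epoch': 1.25, 'accuracy': 0.81}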

View File

@@ -250,20 +250,18 @@ class CustomDPOTrainer(DPOTrainer):
         if self.ftx_gamma > 1e-6:
             losses += self.ftx_gamma * sft_loss
 
-        reward_accuracies = (chosen_rewards > rejected_rewards).float()
-
         prefix = "eval_" if train_eval == "eval" else ""
-        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
-        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
-        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
-        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
-        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
-        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
-        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
+        metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().item()
+        metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().item()
+        metrics[f"{prefix}rewards/accuracies"] = (chosen_rewards > rejected_rewards).float().mean().item()
+        metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().item()
+        metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.mean().item()
+        metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.mean().item()
+        metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.mean().item()
+        metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.mean().item()
         if self.loss_type == "orpo":
-            metrics[f"{prefix}sft_loss"] = sft_loss.detach().mean().cpu()
-            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).detach().mean().cpu()
+            metrics[f"{prefix}sft_loss"] = sft_loss.mean().item()
+            metrics[f"{prefix}odds_ratio_loss"] = ((losses - sft_loss) / self.beta).mean().item()
 
         return losses.mean(), metrics
@@ -275,6 +273,36 @@ class CustomDPOTrainer(DPOTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs)
         if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
 
         return loss
 
+    @override
+    def log(self, logs: Dict[str, float]) -> None:
+        r"""
+        Log `logs` on the various objects watching training, including stored metrics.
+        """
+        # logs either has "loss" or "eval_loss"
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        key_list, metric_list = [], []
+        for key, metrics in self._stored_metrics[train_eval].items():
+            key_list.append(key)
+            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).mean().item())
+
+        del self._stored_metrics[train_eval]
+        if len(metric_list) < 10:  # pad to a fixed length for all-reduce
+            for i in range(10 - len(metric_list)):
+                key_list.append(f"dummy_{i}")
+                metric_list.append(0.0)
+
+        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+        metric_list = self.accelerator.reduce(metric_list, "mean").tolist()
+        for key, metric in zip(key_list, metric_list):  # skip dummy padding entries
+            if not key.startswith("dummy_"):
+                logs[key] = metric
+
+        return Trainer.log(self, logs)
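
The new `log` override above averages each rank's stored DPO metrics locally and then averages them again across processes; padding the metric vector to a fixed length keeps the all-reduce shape identical on every rank even when some ranks stored fewer keys. Below is a minimal standalone sketch of that pattern, not the project's code: `gather_mean_metrics` and `PAD_TO` are illustrative names, it uses `torch.distributed` directly instead of the trainer's `accelerate` wrapper, and it assumes every rank stores the same keys in the same order.

# Minimal sketch of the pad-and-all-reduce metric averaging (illustrative only).
from typing import Dict, List

import torch
import torch.distributed as dist

PAD_TO = 10  # fixed length shared by all ranks


def gather_mean_metrics(stored: Dict[str, List[float]]) -> Dict[str, float]:
    """Average locally stored metrics, then average them across ranks."""
    keys = list(stored.keys())
    values = [sum(v) / len(v) for v in stored.values()]  # local mean per key

    # Pad with dummy entries so every rank reduces a tensor of the same size.
    while len(values) < PAD_TO:
        keys.append(f"dummy_{len(values)}")
        values.append(0.0)

    tensor = torch.tensor(values, dtype=torch.float)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor /= dist.get_world_size()  # mean across ranks

    # Drop the dummy slots after the reduction.
    return {k: v for k, v in zip(keys, tensor.tolist()) if not k.startswith("dummy_")}


if __name__ == "__main__":
    print(gather_mean_metrics({"rewards/chosen": [0.1, 0.3], "rewards/rejected": [-0.2]}))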

View File

@@ -131,7 +131,7 @@ class CustomKTOTrainer(KTOTrainer):
     @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
-    ) -> Tuple["torch.Tensor", "torch.Tensor"]:
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
         r"""
         Runs forward pass and computes the log probabilities.
         """
@@ -151,23 +151,25 @@ class CustomKTOTrainer(KTOTrainer):
         logits = model(**model_inputs, return_dict=True, use_cache=False).logits.to(torch.float32)
         logps, valid_length = get_batch_logps(logits=logits, labels=batch[f"{prefix}labels"])
-        return logps, logps / valid_length
+        return logits, logps, logps / valid_length
 
     @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
-    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
-        target_logps, target_logps_avg = self.forward(model, batch)
+    ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+        target_logits, target_logps, target_logps_avg = self.forward(model, batch)
         with torch.no_grad():
-            kl_logps, _ = self.forward(model, batch, prefix="kl_")
+            _, kl_logps, _ = self.forward(model, batch, prefix="kl_")
 
         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")
 
+        chosen_logits = target_logits[batch["kto_tags"]]
         chosen_logps = target_logps[batch["kto_tags"]]
+        rejected_logits = target_logits[~batch["kto_tags"]]
         rejected_logps = target_logps[~batch["kto_tags"]]
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
-        return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps, chosen_logps_avg
 
     @override
     def compute_reference_log_probs(
@@ -184,7 +186,7 @@ class CustomKTOTrainer(KTOTrainer):
             ref_context = nullcontext()
 
         with torch.no_grad(), ref_context:
-            reference_chosen_logps, reference_rejected_logps, reference_kl_logps, _ = self.concatenated_forward(
+            reference_chosen_logps, reference_rejected_logps, _, _, reference_kl_logps, _ = self.concatenated_forward(
                 ref_model, batch
             )
 
@@ -200,9 +202,14 @@ class CustomKTOTrainer(KTOTrainer):
         Computes the DPO loss and other metrics for the given batch of inputs for train or test.
         """
         metrics = {}
-        policy_chosen_logps, policy_rejected_logps, policy_kl_logps, policy_chosen_logps_avg = (
-            self.concatenated_forward(model, batch)
-        )
+        (
+            policy_chosen_logps,
+            policy_rejected_logps,
+            policy_chosen_logits,
+            policy_rejected_logits,
+            policy_kl_logps,
+            policy_chosen_logps_avg,
+        ) = self.concatenated_forward(model, batch)
         reference_chosen_logps, reference_rejected_logps, reference_kl_logps = self.compute_reference_log_probs(
             model, batch
         )
@@ -220,24 +227,21 @@ class CustomKTOTrainer(KTOTrainer):
             sft_loss = -policy_chosen_logps_avg
             losses += self.ftx_gamma * sft_loss.nanmean() / len(policy_chosen_logps) * len(batch["labels"])
 
-        num_chosen = torch.Tensor([len(chosen_rewards)]).to(self.accelerator.device)
-        num_rejected = torch.Tensor([len(rejected_rewards)]).to(self.accelerator.device)
-
-        all_num_chosen = self.accelerator.gather(num_chosen).sum().item()
-        all_num_rejected = self.accelerator.gather(num_rejected).sum().item()
-
-        if all_num_chosen > 0:
-            metrics["rewards/chosen_sum"] = self.accelerator.gather(chosen_rewards.nansum()).nansum().item()
-            metrics["logps/chosen_sum"] = self.accelerator.gather(policy_chosen_logps.nansum()).nansum().item()
-            metrics["count/chosen"] = all_num_chosen
-
-        if all_num_rejected > 0:
-            metrics["rewards/rejected_sum"] = self.accelerator.gather(rejected_rewards.nansum()).nansum().item()
-            metrics["logps/rejected_sum"] = self.accelerator.gather(policy_rejected_logps.nansum()).nansum().item()
-            metrics["count/rejected"] = all_num_rejected
+        num_chosen = len(chosen_rewards)
+        num_rejected = len(rejected_rewards)
+        if num_chosen > 0:
+            metrics["rewards/chosen_sum"] = chosen_rewards.nansum().item()
+            metrics["logps/chosen_sum"] = policy_chosen_logps.nansum().item()
+            metrics["logits/chosen_sum"] = policy_chosen_logits.nansum().item()
+            metrics["count/chosen"] = float(num_chosen)
+
+        if num_rejected > 0:
+            metrics["rewards/rejected_sum"] = rejected_rewards.nansum().item()
+            metrics["logps/rejected_sum"] = policy_rejected_logps.nansum().item()
+            metrics["logits/rejected_sum"] = policy_rejected_logits.nansum().item()
+            metrics["count/rejected"] = float(num_rejected)
 
         metrics["kl"] = kl.item()
 
         return losses, metrics
 
     @override
@@ -248,6 +252,48 @@ class CustomKTOTrainer(KTOTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs)
         if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
-            loss /= self.args.gradient_accumulation_steps
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
 
         return loss
 
+    @override
+    def log(self, logs: Dict[str, float]) -> None:
+        r"""
+        Log `logs` on the various objects watching training, including stored metrics.
+        """
+        # logs either has "loss" or "eval_loss"
+        train_eval = "train" if "loss" in logs else "eval"
+        prefix = "eval_" if train_eval == "eval" else ""
+        # Add the stored metric sums to logs (converted to averages below)
+        key_list, metric_list = [], []
+        for key, metrics in self._stored_metrics[train_eval].items():
+            key_list.append(key)
+            metric_list.append(torch.tensor(metrics, dtype=torch.float).to(self.accelerator.device).sum().item())
+
+        del self._stored_metrics[train_eval]
+        if len(metric_list) < 9:  # pad to a fixed length for all-reduce
+            for i in range(9 - len(metric_list)):
+                key_list.append(f"dummy_{i}")
+                metric_list.append(0.0)
+
+        metric_list = torch.tensor(metric_list, dtype=torch.float).to(self.accelerator.device)
+        metric_list = self.accelerator.reduce(metric_list, "sum").tolist()
+        metric_dict: Dict[str, float] = dict(zip(key_list, metric_list))
+        for split in ["chosen", "rejected"]:  # compute average metrics from the gathered sums and counts
+            if f"count/{split}" in metric_dict:
+                for key in ("rewards", "logps", "logits"):
+                    logs[f"{prefix}{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / metric_dict[f"count/{split}"]
+                    del metric_dict[f"{key}/{split}_sum"]
+
+                del metric_dict[f"count/{split}"]
+
+        if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:  # calculate reward margin
+            logs[f"{prefix}rewards/margins"] = logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
+
+        for key, metric in metric_dict.items():  # add remaining items
+            if not key.startswith("dummy_"):
+                logs[key] = metric
+
+        return Trainer.log(self, logs)
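
The KTO variant above gathers per-split sums and sample counts rather than means, since chosen and rejected examples can be unevenly distributed across batches and ranks; the averages (and the reward margin) are recovered only after the all-reduce. Below is a standalone sketch of that two-step scheme, not the project's code: the helper names `reduce_sums` and `finalize_kto_logs` are illustrative, and the snippet falls back to a no-op reduce when no process group is initialized.

# Minimal sketch of sum-and-count accumulation followed by averaging (illustrative only).
from typing import Dict

import torch
import torch.distributed as dist


def reduce_sums(metric_dict: Dict[str, float]) -> Dict[str, float]:
    """All-reduce a dict of per-rank sums/counts; assumes every rank holds the same keys."""
    keys = sorted(metric_dict)
    tensor = torch.tensor([metric_dict[k] for k in keys], dtype=torch.float)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return dict(zip(keys, tensor.tolist()))


def finalize_kto_logs(metric_dict: Dict[str, float]) -> Dict[str, float]:
    """Turn gathered sums/counts into per-sample averages plus a reward margin."""
    logs: Dict[str, float] = {}
    for split in ("chosen", "rejected"):
        count = metric_dict.get(f"count/{split}", 0.0)
        if count > 0:
            for key in ("rewards", "logps", "logits"):
                logs[f"{key}/{split}"] = metric_dict[f"{key}/{split}_sum"] / count

    if "rewards/chosen" in logs and "rewards/rejected" in logs:
        logs["rewards/margins"] = logs["rewards/chosen"] - logs["rewards/rejected"]

    return logs


if __name__ == "__main__":
    local = {
        "rewards/chosen_sum": 1.2, "logps/chosen_sum": -30.0, "logits/chosen_sum": 5.0, "count/chosen": 4.0,
        "rewards/rejected_sum": -0.8, "logps/rejected_sum": -42.0, "logits/rejected_sum": 3.0, "count/rejected": 4.0,
    }
    print(finalize_kto_logs(reduce_sums(local)))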

View File

@@ -81,7 +81,7 @@ def run_kto(
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         if trainer.is_world_process_zero() and finetuning_args.plot_loss:
-            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "train/rewards/chosen"])
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "rewards/chosen"])
 
     # Evaluation
     if training_args.do_eval:

View File

@@ -74,6 +74,10 @@ class CustomTrainer(Trainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
         if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+            # other models should not scale the loss
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
 
         return loss
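
The `compute_loss` change repeated across the trainers above exists because, with `return_outputs=True`, `super().compute_loss` returns a `(loss, outputs)` tuple rather than a bare tensor, so dividing the whole return value by the gradient-accumulation steps would fail; only the loss element should be scaled. A minimal sketch of that logic in isolation (function and argument names are illustrative, not the repository's API):

# Scale only the loss element for gradient accumulation (illustrative helper).
from typing import Any, Tuple, Union

import torch


def scale_loss_for_grad_accum(
    loss: Union[torch.Tensor, Tuple[torch.Tensor, Any]],
    grad_accum_steps: int,
    return_outputs: bool,
) -> Union[torch.Tensor, Tuple[Any, ...]]:
    if return_outputs:
        # loss is (loss_tensor, *outputs): scale the first element, pass the rest through.
        return (loss[0] / grad_accum_steps, *loss[1:])

    return loss / grad_accum_steps


if __name__ == "__main__":
    plain = torch.tensor(2.0)
    print(scale_loss_for_grad_accum(plain, 4, return_outputs=False))       # tensor(0.5000)
    bundled = (torch.tensor(2.0), {"logits": torch.zeros(1)})
    print(scale_loss_for_grad_accum(bundled, 4, return_outputs=True)[0])   # tensor(0.5000)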

View File

@@ -87,7 +87,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         """
         loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
         if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
-            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+            # other models should not scale the loss
+            if return_outputs:
+                return (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
+            else:
+                return loss / self.args.gradient_accumulation_steps
 
         return loss

View File

@@ -144,8 +144,8 @@ class WebChatModel(ChatModel):
             messages,
             system,
             tools,
-            images=[image],
-            videos=[video],
+            images=[image] if image else None,
+            videos=[video] if video else None,
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             temperature=temperature,

View File

@@ -98,6 +98,7 @@ class Runner:
     def _finalize(self, lang: str, finish_info: str) -> str:
         finish_info = ALERTS["info_aborted"][lang] if self.aborted else finish_info
+        gr.Info(finish_info)
         self.trainer = None
         self.aborted = False
         self.running = False
@@ -357,6 +358,7 @@ class Runner:
         progress_bar = self.manager.get_elem_by_id("{}.progress_bar".format("train" if self.do_train else "eval"))
         loss_viewer = self.manager.get_elem_by_id("train.loss_viewer") if self.do_train else None
 
+        running_log = ""
         while self.trainer is not None:
             if self.aborted:
                 yield {
@@ -392,7 +394,7 @@ class Runner:
                 finish_info = ALERTS["err_failed"][lang]
 
             return_dict = {
-                output_box: self._finalize(lang, finish_info),
+                output_box: self._finalize(lang, finish_info) + "\n\n" + running_log,
                 progress_bar: gr.Slider(visible=False),
             }
             yield return_dict