From e75024fde3b8f96a706a8915dd890354449b1d5e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 14 Aug 2023 00:23:56 +0800
Subject: [PATCH] fix #480

Former-commit-id: 2f2fd55d8175eb3c6ce94bc821ab4e6331f79d8e
---
 src/llmtuner/tuner/dpo/trainer.py | 16 +++++++++-------
 src/llmtuner/webui/utils.py       |  1 +
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/tuner/dpo/trainer.py
index a94642c1..f0da9eaa 100644
--- a/src/llmtuner/tuner/dpo/trainer.py
+++ b/src/llmtuner/tuner/dpo/trainer.py
@@ -2,7 +2,7 @@ import torch
 from collections import defaultdict
 from peft import PeftModel
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
-from transformers import Trainer
+from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
 
 from llmtuner.extras.constants import IGNORE_INDEX
@@ -43,21 +43,23 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         model: Optional[torch.nn.Module] = None,
         batch: Optional[Dict[str, torch.Tensor]] = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
         unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+
         if not torch.is_grad_enabled():
             unwrapped_model.gradient_checkpointing_disable()
 
         if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
             with unwrapped_model.disable_adapter():
-                all_logits: torch.Tensor = self.model(
-                    batch["input_ids"],
-                    attention_mask=batch["attention_mask"],
+                all_logits = self.model(
+                    input_ids=batch_copied["input_ids"],
+                    attention_mask=batch_copied["attention_mask"],
                     return_dict=True
                 ).logits.to(torch.float32)
         else:
-            all_logits: torch.Tensor = model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
+            all_logits = model(
+                input_ids=batch_copied["input_ids"],
+                attention_mask=batch_copied["attention_mask"],
                 return_dict=True
             ).logits.to(torch.float32)
 
diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py
index 168fbe43..152df6e3 100644
--- a/src/llmtuner/webui/utils.py
+++ b/src/llmtuner/webui/utils.py
@@ -63,6 +63,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
 
 
 def gen_cmd(args: Dict[str, Any]) -> str:
+    args["plot_loss"] = True
     cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
     for k, v in args.items():
         if v is not None and v != "":
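
Note (not part of the patch): the trainer.py change obtains the DPO reference logits by running the same PEFT model with its adapter temporarily disabled, feeding it a detached copy of the batch instead of the original tensors. The standalone sketch below illustrates that pattern only; the helper name compute_ref_logits is hypothetical, and torch.no_grad() is added here because the sketch stands alone, whereas the actual trainer manages gradient state and model unwrapping through Accelerate.

import torch
from peft import PeftModel
from transformers import BatchEncoding


def compute_ref_logits(model: PeftModel, batch: BatchEncoding) -> torch.Tensor:
    """Hypothetical helper: reference logits from a LoRA model without a separate ref model."""
    # Detach and clone the inputs so this extra forward pass does not share
    # tensors with the rest of the training step (the "avoid error" copy in the patch).
    batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})
    # With the adapter disabled, the PEFT model behaves as the frozen base model,
    # which serves as the reference policy in DPO.
    with torch.no_grad(), model.disable_adapter():
        logits = model(
            input_ids=batch_copied["input_ids"],
            attention_mask=batch_copied["attention_mask"],
            return_dict=True,
        ).logits.to(torch.float32)
    return logits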