Former-commit-id: 2f2fd55d8175eb3c6ce94bc821ab4e6331f79d8e
This commit is contained in:
hiyouga 2023-08-14 00:23:56 +08:00
parent 7984ae8b62
commit e75024fde3
2 changed files with 10 additions and 7 deletions

View File

@ -2,7 +2,7 @@ import torch
from collections import defaultdict from collections import defaultdict
from peft import PeftModel from peft import PeftModel
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
from transformers import Trainer from transformers import BatchEncoding, Trainer
from trl import DPOTrainer from trl import DPOTrainer
from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.constants import IGNORE_INDEX
@ -43,21 +43,23 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
model: Optional[torch.nn.Module] = None, model: Optional[torch.nn.Module] = None,
batch: Optional[Dict[str, torch.Tensor]] = None batch: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model) unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
if not torch.is_grad_enabled(): if not torch.is_grad_enabled():
unwrapped_model.gradient_checkpointing_disable() unwrapped_model.gradient_checkpointing_disable()
if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
with unwrapped_model.disable_adapter(): with unwrapped_model.disable_adapter():
all_logits: torch.Tensor = self.model( all_logits = self.model(
batch["input_ids"], input_ids=batch_copied["input_ids"],
attention_mask=batch["attention_mask"], attention_mask=batch_copied["attention_mask"],
return_dict=True return_dict=True
).logits.to(torch.float32) ).logits.to(torch.float32)
else: else:
all_logits: torch.Tensor = model( all_logits = model(
batch["input_ids"], input_ids=batch_copied["input_ids"],
attention_mask=batch["attention_mask"], attention_mask=batch_copied["attention_mask"],
return_dict=True return_dict=True
).logits.to(torch.float32) ).logits.to(torch.float32)

View File

@ -63,6 +63,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
def gen_cmd(args: Dict[str, Any]) -> str: def gen_cmd(args: Dict[str, Any]) -> str:
args["plot_loss"] = True
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "] cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
for k, v in args.items(): for k, v in args.items():
if v is not None and v != "": if v is not None and v != "":