Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)
commit e75024fde3 (parent 7984ae8b62)
@@ -2,7 +2,7 @@ import torch
 from collections import defaultdict
 from peft import PeftModel
 from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
-from transformers import Trainer
+from transformers import BatchEncoding, Trainer
 from trl import DPOTrainer
 
 from llmtuner.extras.constants import IGNORE_INDEX
@@ -43,21 +43,23 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
         model: Optional[torch.nn.Module] = None,
         batch: Optional[Dict[str, torch.Tensor]] = None
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
         unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+
         if not torch.is_grad_enabled():
             unwrapped_model.gradient_checkpointing_disable()
 
         if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
             with unwrapped_model.disable_adapter():
-                all_logits: torch.Tensor = self.model(
-                    batch["input_ids"],
-                    attention_mask=batch["attention_mask"],
+                all_logits = self.model(
+                    input_ids=batch_copied["input_ids"],
+                    attention_mask=batch_copied["attention_mask"],
                     return_dict=True
                 ).logits.to(torch.float32)
         else:
-            all_logits: torch.Tensor = model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
+            all_logits = model(
+                input_ids=batch_copied["input_ids"],
+                attention_mask=batch_copied["attention_mask"],
                 return_dict=True
             ).logits.to(torch.float32)
 
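Note on the hunk above: concatenated_forward now feeds detached clones of the batch tensors, wrapped in a BatchEncoding, into both forward passes, so the reference pass run under disable_adapter() cannot touch the tensors the policy pass reuses. A minimal sketch of what detach().clone() buys here, using toy tensors rather than the repo's own collated inputs:

import torch
from transformers import BatchEncoding

# Toy batch standing in for the collated DPO inputs (hypothetical values).
batch = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
}

# detach().clone() gives each tensor independent storage outside the autograd graph,
# and BatchEncoding keeps the dict-style access the forward call expects.
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()})

batch_copied["input_ids"][0, 0] = 99            # mutate the copy only
assert batch["input_ids"][0, 0].item() == 1     # the original batch is untouched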
@@ -63,6 +63,7 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
 
 
 def gen_cmd(args: Dict[str, Any]) -> str:
+    args["plot_loss"] = True
     cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
     for k, v in args.items():
         if v is not None and v != "":
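For the WebUI change above, gen_cmd now forces plot_loss on before serializing the arguments, so every command generated by the UI plots the loss curve. A rough standalone sketch of the effect; the helper name, script path, and argument formatting are illustrative, not the repo's exact implementation:

from typing import Any, Dict

def gen_cmd_sketch(args: Dict[str, Any]) -> str:
    args["plot_loss"] = True  # always enabled for commands generated by the WebUI
    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python train.py"]  # script name is a placeholder
    for k, v in args.items():
        if v is not None and v != "":  # skip unset options, mirroring the shown condition
            cmd_lines.append("--{} {}".format(k, str(v)))
    return " \\\n    ".join(cmd_lines)

print(gen_cmd_sketch({"stage": "sft", "model_name_or_path": "llama-7b"}))
# CUDA_VISIBLE_DEVICES=0 python train.py \
#     --stage sft \
#     --model_name_or_path llama-7b \
#     --plot_loss True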