mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)
parent e6f4eab4ab
commit d05d535a58
src/llmtuner/tuner/dpo/trainer.py

@@ -10,7 +10,7 @@ from llmtuner.tuner.core.trainer import PeftModelMixin
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
-    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import FinetuningArguments
 
 
 class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
@@ -18,12 +18,10 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
     def __init__(
         self,
         finetuning_args: "FinetuningArguments",
-        generating_args: "GeneratingArguments",
         ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
         **kwargs
     ):
         self.finetuning_args = finetuning_args
-        self.generating_args = generating_args
         self.ref_model = ref_model
         self.use_dpo_data_collator = True # hack to avoid warning
         self.label_pad_token_id = IGNORE_INDEX
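A note on the unchanged `self.label_pad_token_id = IGNORE_INDEX` line above: `IGNORE_INDEX` in llmtuner is the usual -100 sentinel, which PyTorch's cross-entropy skips, so padded label positions never contribute to the loss. A minimal sketch of that behavior (the tensors here are made-up illustration, not llmtuner code):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # llmtuner's constant; also F.cross_entropy's default ignore_index

logits = torch.randn(3, 5)                   # 3 token positions, vocabulary of 5
labels = torch.tensor([2, IGNORE_INDEX, 4])  # middle position is label padding

# The IGNORE_INDEX position is excluded from the loss entirely, which is why
# the trainer pads labels with it rather than with a real token id.
loss = F.cross_entropy(logits, labels, ignore_index=IGNORE_INDEX)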
src/llmtuner/tuner/dpo/workflow.py

@@ -13,7 +13,7 @@ from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
 
 
 def run_dpo(
@@ -21,7 +21,6 @@ def run_dpo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    generating_args: "GeneratingArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
@@ -38,7 +37,6 @@ def run_dpo(
     # Initialize our Trainer
     trainer = DPOPeftTrainer(
         finetuning_args=finetuning_args,
-        generating_args=generating_args,
         ref_model=ref_model,
         model=model,
         args=training_args,
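Taken together, the two files make the same API change: DPOPeftTrainer.__init__ and run_dpo no longer accept generating_args, so any caller that passed one must drop it. A minimal caller-side sketch after this commit, assuming the llmtuner package at this commit is importable and that run_dpo lives at llmtuner.tuner.dpo.workflow as the hunk context suggests; the HfArgumentParser setup is a hypothetical stand-in for however a real caller builds these objects:

from transformers import HfArgumentParser, Seq2SeqTrainingArguments

from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
from llmtuner.tuner.dpo.workflow import run_dpo

# Parse the argument dataclasses from the command line (hypothetical setup).
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments))
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

# Before this commit the call also took generating_args=...; it must now be omitted.
run_dpo(
    model_args=model_args,
    data_args=data_args,
    training_args=training_args,
    finetuning_args=finetuning_args,
)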