Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 11:50:35 +08:00)
support DPO training (2305.18290)
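The title cites arXiv 2305.18290, the Direct Preference Optimization paper (Rafailov et al., 2023). For background, here is a minimal sketch of the DPO objective defined in that paper — illustrative only, not code from this commit; the function name and argument names are hypothetical:

```python
import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,       # log pi_theta(y_w | x), shape (batch,)
    policy_rejected_logps: torch.Tensor,     # log pi_theta(y_l | x)
    reference_chosen_logps: torch.Tensor,    # log pi_ref(y_w | x)
    reference_rejected_logps: torch.Tensor,  # log pi_ref(y_l | x)
    beta: float = 0.1,
) -> torch.Tensor:
    # Implicit rewards are beta-scaled log-ratios against the frozen reference model.
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)
    # -log(sigmoid(m)) == softplus(-m): a numerically stable form of the DPO loss.
    return F.softplus(-(chosen_rewards - rejected_rewards)).mean()
```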
```diff
@@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
     from llmtuner.hparams import FinetuningArguments
 
 
 logger = get_logger(__name__)
 
 
-class PeftTrainer(Seq2SeqTrainer):
+class PeftModelMixin:
     r"""
-    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
     """
 
-    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
-        super().__init__(**kwargs)
-        self.finetuning_args = finetuning_args
-        self._remove_log()
-
-    def _remove_log(self):
-        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
-            logger.warning("Previous log file in this folder will be deleted.")
-            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
+    def __init__(self) -> None: # for type checking
+        self.model: PreTrainedModel = None
+        self.tokenizer: "PreTrainedTokenizer" = None
+        self.args: "Seq2SeqTrainingArguments" = None
+        self.finetuning_args: "FinetuningArguments" = None
+        self.state: "TrainerState" = None
+        raise AssertionError("Mixin should not be initialized.")
 
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""
```
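The mixin's `__init__` exists only so static type checkers know which attributes (`model`, `tokenizer`, `args`, `finetuning_args`, `state`) the patched save/load methods rely on; at runtime the concrete trainer never calls it, and the `AssertionError` guards against instantiating the mixin on its own. A minimal standalone sketch of the same pattern, with illustrative names not taken from the repo:

```python
class GreeterMixin:
    def __init__(self) -> None:  # for type checking only; never executed
        self.name: str = None
        raise AssertionError("Mixin should not be initialized.")

    def greet(self) -> str:
        # Relies on `self.name`, which the concrete class must set.
        return f"hello from {self.name}"


class Worker(GreeterMixin):
    def __init__(self, name: str) -> None:
        # Deliberately does not call GreeterMixin.__init__.
        self.name = name


print(Worker("w0").greet())  # hello from w0
```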
```diff
@@ -96,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
             model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
         else: # freeze/full-tuning
             load_trainable_params(model, self.state.best_model_checkpoint)
+
+
+class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
+    r"""
+    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
+        Seq2SeqTrainer.__init__(self, **kwargs)
+        self.finetuning_args = finetuning_args
```
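Note that `PeftTrainer.__init__` calls `Seq2SeqTrainer.__init__` explicitly rather than `super().__init__(**kwargs)`: with the MRO `PeftTrainer -> PeftModelMixin -> Seq2SeqTrainer`, `super()` would dispatch to the mixin's guard `__init__` and raise. A toy demonstration of that dispatch order, with illustrative class names not from the repo:

```python
class Mixin:
    def __init__(self) -> None:
        raise AssertionError("Mixin should not be initialized.")


class Base:
    def __init__(self, value: int) -> None:
        self.value = value


class Combined(Mixin, Base):
    def __init__(self, value: int) -> None:
        # super().__init__() would hit Mixin.__init__ first and raise,
        # so the base class is initialized explicitly instead.
        Base.__init__(self, value)


print(Combined.__mro__)   # (Combined, Mixin, Base, object)
print(Combined(3).value)  # 3
```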