From 5d96cf146e4858e62fcb9fe548f12b2ef2bfda99 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 3 Jun 2024 22:08:38 +0800
Subject: [PATCH] Update trainer.py

Former-commit-id: 24499f40dc1d9db448a3328d2a75c60eec27feb9
---
 src/llamafactory/train/kto/trainer.py | 23 +++++++----------------
 1 file changed, 7 insertions(+), 16 deletions(-)

diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 3f1220a9..7c0343f5 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -4,7 +4,6 @@ from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
-from torch.utils.data import RandomSampler
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
@@ -14,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
+    import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin
 
     from ...hparams import FinetuningArguments
@@ -85,6 +85,12 @@ class CustomKTOTrainer(KTOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        r"""
+        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
+        """
+        return Trainer._get_train_sampler(self)
+
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
         super()._save(output_dir, state_dict)
         if self.processor is not None:
@@ -174,21 +180,6 @@ class CustomKTOTrainer(KTOTrainer):
 
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
 
-    def has_length(self,dataset):
-        """
-        Checks if the dataset implements __len__() and it doesn't raise an error
-        """
-        try:
-            return len(dataset) is not None
-        except TypeError:
-            # TypeError: len() of unsized object
-            return False
-
-    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.train_dataset is None or not self.has_length(self.train_dataset):
-            return None
-        return RandomSampler(self.train_dataset)
-
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
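
The patch works because trl's KTOTrainer overrides _get_train_sampler to return a sequential sampler, so each epoch would visit examples in dataset order. Calling Trainer._get_train_sampler(self) as an unbound function skips that override in the method resolution order and restores the base transformers.Trainer behavior, which returns a random sampler for sized datasets. The following self-contained sketch illustrates the delegation pattern with toy stand-ins (ToyDataset, BaseTrainer, SequentialTrainer, CustomTrainer are hypothetical names, not LLaMA-Factory code, and BaseTrainer is a simplification of what transformers.Trainer actually does):

from typing import Optional

from torch.utils.data import Dataset, RandomSampler, Sampler, SequentialSampler


class ToyDataset(Dataset):
    """Sized dataset stand-in; only __len__/__getitem__ matter here."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> int:
        return index


class BaseTrainer:
    """Stand-in for transformers.Trainer: random sampling for sized datasets
    (simplified; the real method also handles group-by-length sampling, etc.)."""

    def __init__(self, train_dataset: Optional[Dataset] = None) -> None:
        self.train_dataset = train_dataset

    def _get_train_sampler(self) -> Optional[Sampler]:
        if self.train_dataset is None or not hasattr(self.train_dataset, "__len__"):
            return None
        return RandomSampler(self.train_dataset)


class SequentialTrainer(BaseTrainer):
    """Stand-in for trl's KTOTrainer, which pins a sequential sampler."""

    def _get_train_sampler(self) -> Optional[Sampler]:
        return SequentialSampler(self.train_dataset)


class CustomTrainer(SequentialTrainer):
    """Mirrors the patch: call the grandparent implementation directly on
    self, bypassing the sequential override one level up the MRO."""

    def _get_train_sampler(self) -> Optional[Sampler]:
        return BaseTrainer._get_train_sampler(self)


if __name__ == "__main__":
    sampler = CustomTrainer(ToyDataset())._get_train_sampler()
    print(type(sampler).__name__)  # -> RandomSampler, not SequentialSampler

Delegating to the grandparent implementation, rather than re-implementing has_length and the RandomSampler wiring as the deleted block did, keeps the sampler logic in sync with whatever transformers.Trainer does upstream and is why the patch can drop the RandomSampler import entirely.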