diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 3f1220a9..7c0343f5 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -4,7 +4,6 @@ from types import MethodType
 from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
 
 import torch
-from torch.utils.data import RandomSampler
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
@@ -14,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
+    import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin
 
     from ...hparams import FinetuningArguments
@@ -85,6 +85,12 @@ class CustomKTOTrainer(KTOTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        r"""
+        Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
+        """
+        return Trainer._get_train_sampler(self)
+
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
         super()._save(output_dir, state_dict)
         if self.processor is not None:
@@ -174,21 +180,6 @@ class CustomKTOTrainer(KTOTrainer):
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
 
-    def has_length(self,dataset):
-        """
-        Checks if the dataset implements __len__() and it doesn't raise an error
-        """
-        try:
-            return len(dataset) is not None
-        except TypeError:
-            # TypeError: len() of unsized object
-            return False
-
-    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.train_dataset is None or not self.has_length(self.train_dataset):
-            return None
-        return RandomSampler(self.train_dataset)
-
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
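
Note: the added override replaces trl's sequential sampler by calling transformers.Trainer._get_train_sampler unbound on self, so the grandparent's shuffling behavior is used instead of the immediate parent's, and the hand-rolled has_length/_get_train_sampler copy becomes unnecessary. Below is a minimal, self-contained sketch of that delegation pattern only; the class names Base, Middle, and Custom are illustrative and not part of the patch.

class Base:
    def sampler(self):
        # Stands in for transformers.Trainer._get_train_sampler (random sampling).
        return "random"


class Middle(Base):
    def sampler(self):
        # Stands in for trl.KTOTrainer's sequential sampler.
        return "sequential"


class Custom(Middle):
    def sampler(self):
        # super().sampler() would hit Middle; calling Base explicitly skips the
        # immediate parent, mirroring `return Trainer._get_train_sampler(self)`.
        return Base.sampler(self)


assert Custom().sampler() == "random"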