diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 82ae722d..7c0343f5 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -13,6 +13,7 @@ from ..utils import create_custom_optimzer, create_custom_scheduler if TYPE_CHECKING: + import torch.utils.data from transformers import PreTrainedModel, ProcessorMixin from ...hparams import FinetuningArguments @@ -84,6 +85,12 @@ class CustomKTOTrainer(KTOTrainer): create_custom_scheduler(self.args, num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer) + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + r""" + Replaces the sequential sampler of KTO Trainer created by trl with the random sampler. + """ + return Trainer._get_train_sampler(self) + def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None: super()._save(output_dir, state_dict) if self.processor is not None: