mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 12:50:38 +08:00
support disable shuffling
This commit is contained in:
@@ -81,6 +81,13 @@ class PairwiseTrainer(Trainer):
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
@override
|
||||
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
|
||||
if self.finetuning_args.disable_shuffling:
|
||||
return torch.utils.data.SequentialSampler(self.train_dataset)
|
||||
|
||||
return super()._get_train_sampler()
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
|
||||
|
||||
Reference in New Issue
Block a user