support disable shuffling

This commit is contained in:
hiyouga
2024-12-19 08:53:21 +00:00
parent 6ccd64ecd9
commit c7cedc7569
9 changed files with 139 additions and 12 deletions

View File

@@ -81,6 +81,13 @@ class PairwiseTrainer(Trainer):
create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
if self.finetuning_args.disable_shuffling:
return torch.utils.data.SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
@override
def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs