From ae9ad13f2ad2139e01e245a42f6191bbc861c6d7 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 26 Mar 2024 23:39:56 +0800
Subject: [PATCH] fix ds optimizer

Former-commit-id: 3bcd41b639899e72bcabc51d59bac8967af19899
---
 src/llmtuner/train/dpo/trainer.py | 4 ++--
 src/llmtuner/train/pt/trainer.py  | 4 ++--
 src/llmtuner/train/rm/trainer.py  | 4 ++--
 src/llmtuner/train/sft/trainer.py | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/llmtuner/train/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py
index ec1cb94f..ed0fe5f1 100644
--- a/src/llmtuner/train/dpo/trainer.py
+++ b/src/llmtuner/train/dpo/trainer.py
@@ -64,10 +64,10 @@ class CustomDPOTrainer(DPOTrainer):
                 self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
 
     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
 
     def sft_loss(self, chosen_logits: torch.FloatTensor, chosen_labels: torch.LongTensor) -> torch.Tensor:
diff --git a/src/llmtuner/train/pt/trainer.py b/src/llmtuner/train/pt/trainer.py
index 48954e57..16e3f5f0 100644
--- a/src/llmtuner/train/pt/trainer.py
+++ b/src/llmtuner/train/pt/trainer.py
@@ -23,8 +23,8 @@ class CustomTrainer(Trainer):
         self.finetuning_args = finetuning_args
 
     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
diff --git a/src/llmtuner/train/rm/trainer.py b/src/llmtuner/train/rm/trainer.py
index 4fbd2318..4f5d7190 100644
--- a/src/llmtuner/train/rm/trainer.py
+++ b/src/llmtuner/train/rm/trainer.py
@@ -30,10 +30,10 @@ class PairwiseTrainer(Trainer):
         self.can_return_loss = True  # override property to return eval_loss
 
     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
 
     def compute_loss(
diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py
index 8e329250..4a49bb27 100644
--- a/src/llmtuner/train/sft/trainer.py
+++ b/src/llmtuner/train/sft/trainer.py
@@ -30,10 +30,10 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         self.finetuning_args = finetuning_args
 
     def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
-        self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
         if self.optimizer is None:
-            self.create_optimizer()
+            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
+            self.create_optimizer()
         self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
 
     def prediction_step(
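
Each hunk applies the same fix: the custom optimizer is no longer created unconditionally, so an optimizer that the DeepSpeed integration ("ds" in the subject) may already have attached to the trainer is not overwritten; create_custom_optimzer and the stock create_optimizer() fallback now run only while self.optimizer is still None. A minimal sketch of the resulting method after the patch, assuming the Hugging Face Trainer API in which create_optimizer() leaves an already-set self.optimizer untouched:

    def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
        if self.optimizer is None:  # skip when an optimizer was already provided (e.g. by DeepSpeed)
            # may return None when no custom optimizer is requested for this run
            self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args, num_training_steps)
            self.create_optimizer()  # default Trainer fallback; only builds one if self.optimizer is still None
        self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)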