From e1751f6398b941a510432fcfddeaade1fe6ebd4b Mon Sep 17 00:00:00 2001 From: hzhaoy Date: Thu, 27 Jun 2024 13:49:57 +0800 Subject: [PATCH] fix #4579 Former-commit-id: 677c86594e4ea904fde0a557852daf54636b06ae --- src/llamafactory/train/sft/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 8f18317f..f0a86dff 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -53,6 +53,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.processor = processor if finetuning_args.pissa_convert: + if self.is_deepspeed_enabled: + self.accelerator.deepspeed_config = self.accelerator.state.deepspeed_plugin.deepspeed_config + self.deepspeed = self._wrap_model(self.model_wrapped) self.save_model(os.path.join(self.args.output_dir, "pissa_init")) if finetuning_args.use_badam: