From 89d9dd5aa558a78d189462dc65549e6cddb16143 Mon Sep 17 00:00:00 2001 From: hzhaoy Date: Thu, 27 Jun 2024 13:49:57 +0800 Subject: [PATCH] fix #4579 Former-commit-id: 0fa298ff6a4febea36ea9f11c7594277a77e6e9b --- src/llamafactory/train/sft/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 8f18317f..f0a86dff 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -53,6 +53,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.processor = processor if finetuning_args.pissa_convert: + if self.is_deepspeed_enabled: + self.accelerator.deepspeed_config = self.accelerator.state.deepspeed_plugin.deepspeed_config + self.deepspeed = self._wrap_model(self.model_wrapped) self.save_model(os.path.join(self.args.output_dir, "pissa_init")) if finetuning_args.use_badam: