From 3d291a82d34a09d0fc221a246962bcbb376b4c6b Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 30 Nov 2023 21:47:06 +0800
Subject: [PATCH] fix #1597

Former-commit-id: 327d7f7efe1fefe4bf4646c07fc4917a42c13383
---
 src/llmtuner/train/ppo/trainer.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index 1bba733b..b81aa7ff 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -6,7 +6,9 @@ from tqdm import tqdm
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_pt_utils import remove_dummy_checkpoint
 
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
@@ -55,6 +57,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         self.state = TrainerState()
         self.control = TrainerControl()
+        self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
+            self.accelerator.state, "deepspeed_plugin"
+        )
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
 
@@ -62,10 +67,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             logger.info("max_steps is given, it will override any value given in num_train_epochs")
 
         if reward_model is not None:
-            is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
-                self.accelerator.state, "deepspeed_plugin"
-            )
-            if is_deepspeed_enabled:
+            if self.is_deepspeed_enabled:
                 if not (
                     getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False)
                     or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False)
@@ -345,4 +347,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         Subclass and override to inject custom behavior.
         """
         if self.args.should_save:
-            self._save(output_dir)
+            try:
+                self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
+            except ValueError:
+                logger.warning(
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+                    " zero_to_fp32.py to recover weights"
+                )
+                self._save(output_dir, state_dict={})
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                self.model.save_checkpoint(output_dir)  # wrapped model
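
For reference, here is a minimal sketch of the save path this patch gives CustomPPOTrainer.save_model, written as a standalone helper so it can be read outside the class. The helper name save_ppo_model and the trainer argument are illustrative and not part of the patch; the calls inside mirror the last hunk above (Accelerate's get_state_dict, transformers' remove_dummy_checkpoint, and DeepSpeed's save_checkpoint on the wrapped engine).

from typing import Optional

from transformers.trainer_pt_utils import remove_dummy_checkpoint
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME


def save_ppo_model(trainer, output_dir: Optional[str] = None) -> None:
    """Save the policy model, falling back to a DeepSpeed checkpoint under ZeRO-3."""
    if not trainer.args.should_save:
        return
    try:
        # Under ZeRO-3 this gathers the full weights on the main process; it raises
        # ValueError when stage3_gather_16bit_weights_on_model_save is disabled.
        trainer._save(output_dir, state_dict=trainer.accelerator.get_state_dict(trainer.model))
    except ValueError:
        # Weights stay sharded: write and then remove a dummy weights file, and let
        # DeepSpeed dump its own checkpoint, which zero_to_fp32.py can later convert.
        trainer._save(output_dir, state_dict={})
        remove_dummy_checkpoint(trainer.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
        trainer.model.save_checkpoint(output_dir)  # wrapped (DeepSpeed engine) model

The ValueError branch reproduces the fallback used by transformers' own Trainer.save_model under DeepSpeed: when the 16-bit weights cannot be gathered on the main process, only the sharded ZeRO checkpoint is written, and zero_to_fp32.py is needed afterwards to reconstruct a consolidated state dict.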