diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py index 1bba733b..b81aa7ff 100644 --- a/src/llmtuner/train/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -6,7 +6,9 @@ from tqdm import tqdm from typing import TYPE_CHECKING, List, Optional, Tuple from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl +from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from transformers.trainer_pt_utils import remove_dummy_checkpoint from trl import PPOTrainer from trl.core import PPODecorators, logprobs_from_logits @@ -55,6 +57,9 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.state = TrainerState() self.control = TrainerControl() + self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( + self.accelerator.state, "deepspeed_plugin" + ) self.log_callback, self.save_callback = callbacks[0], callbacks[1] assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback) @@ -62,10 +67,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): logger.info("max_steps is given, it will override any value given in num_train_epochs") if reward_model is not None: - is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( - self.accelerator.state, "deepspeed_plugin" - ) - if is_deepspeed_enabled: + if self.is_deepspeed_enabled: if not ( getattr(reward_model.pretrained_model, "is_loaded_in_8bit", False) or getattr(reward_model.pretrained_model, "is_loaded_in_4bit", False) @@ -345,4 +347,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer): Subclass and override to inject custom behavior. """ if self.args.should_save: - self._save(output_dir) + try: + self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model)) + except ValueError: + logger.warning( + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + self._save(output_dir, state_dict={}) + remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) + self.model.save_checkpoint(output_dir) # wrapped model