diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py
index f02dbdc3..40129840 100644
--- a/src/llmtuner/train/ppo/trainer.py
+++ b/src/llmtuner/train/ppo/trainer.py
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
-from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 
@@ -361,9 +360,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
             except ValueError:
                 logger.warning(
-                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                    " zero_to_fp32.py to recover weights"
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+                    " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:  # remove dummy checkpoint
+                    file = os.path.join(output_dir, filename)
+                    if os.path.isfile(file):
+                        os.remove(file)
+                self.model.save_checkpoint(output_dir)  # wrapped model
 
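
Note: the inlined loop reproduces the cleanup that the removed transformers.trainer_pt_utils.remove_dummy_checkpoint helper performed. A minimal standalone sketch of that step (the function name and explicit output_dir argument below are illustrative, not part of the patch):

import os

from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME  # "model.safetensors", "pytorch_model.bin"


def remove_dummy_checkpoint(output_dir: str) -> None:
    # Under DeepSpeed ZeRO-3 with stage3_gather_16bit_weights_on_model_save=false,
    # _save(output_dir, state_dict={}) writes placeholder weight files; delete them
    # so they are not mistaken for real checkpoints.
    for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
        file = os.path.join(output_dir, filename)
        if os.path.isfile(file):
            os.remove(file)

Once the dummy files are gone, self.model.save_checkpoint(output_dir) (the DeepSpeed-wrapped model) writes the sharded ZeRO checkpoint that zero_to_fp32.py can later consolidate into full-precision weights.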