fix ppo trainer save logic

Former-commit-id: d3dccd0693ede18a99f04780f2fd6e3a89810405
commit 027caabbb6
parent 58cbc9e0be
Author: hiyouga
Date:   2023-12-04 19:00:19 +08:00

@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
-from transformers.trainer_pt_utils import remove_dummy_checkpoint
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
@@ -361,9 +360,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
         except ValueError:
             logger.warning(
-                " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                " zero_to_fp32.py to recover weights"
+                " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+                " use zero_to_fp32.py to recover weights"
             )
             self._save(output_dir, state_dict={})
-            remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+            for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:  # remove dummy checkpoint
+                file = os.path.join(output_dir, filename)
+                if os.path.isfile(file):
+                    os.remove(file)
             self.model.save_checkpoint(output_dir)  # wrapped model
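
For reference, the inlined loop reproduces the behavior of transformers.trainer_pt_utils.remove_dummy_checkpoint without depending on a transformers release that ships that helper. Below is a minimal standalone sketch of the equivalent cleanup; the filename constants are hard-coded stand-ins for the transformers.utils values only to keep the example self-contained:

import os

# Stand-ins for transformers.utils.WEIGHTS_NAME / SAFE_WEIGHTS_NAME;
# the patched trainer imports the real constants instead.
WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_NAME = "model.safetensors"

def remove_dummy_checkpoint(output_dir: str) -> None:
    """Delete the placeholder weight files left by _save(output_dir, state_dict={}).

    Under DeepSpeed ZeRO-3 with stage3_gather_16bit_weights_on_model_save=false,
    the full state dict cannot be gathered, so _save() writes an empty (dummy)
    checkpoint only to emit the config and tokenizer files; the dummy weight
    files are then removed before save_checkpoint() writes the real sharded
    weights.
    """
    for filename in (WEIGHTS_NAME, SAFE_WEIGHTS_NAME):
        file = os.path.join(output_dir, filename)
        if os.path.isfile(file):  # present only when a dummy file was written
            os.remove(file)

Note that the patch drops the is_main_process guard the upstream helper takes as its first argument (the old call passed self.args.should_save), presumably because this save path already runs only on the process that should save.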