Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)

fix ppo trainer save logic

Former-commit-id: d3dccd0693ede18a99f04780f2fd6e3a89810405
Parent: 58cbc9e0be
Commit: 027caabbb6
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
-from transformers.trainer_pt_utils import remove_dummy_checkpoint
 
 from trl import PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
@@ -361,9 +360,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                 self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
             except ValueError:
                 logger.warning(
-                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
-                    " zero_to_fp32.py to recover weights"
+                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
+                    " use zero_to_fp32.py to recover weights"
                 )
                 self._save(output_dir, state_dict={})
-                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:  # remove dummy checkpoint
+                    file = os.path.join(output_dir, filename)
+                    if os.path.isfile(file):
+                        os.remove(file)
+
                 self.model.save_checkpoint(output_dir)  # wrapped model
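For quick reference, here is a minimal standalone sketch of the cleanup step that the second hunk inlines in place of remove_dummy_checkpoint. The helper name remove_dummy_weight_files is hypothetical (the commit keeps the loop inline inside save_model); the filename constants come from transformers.utils, as in the patched code.

import os

from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME  # "pytorch_model.bin", "model.safetensors"


def remove_dummy_weight_files(output_dir: str) -> None:
    """Delete the placeholder weight files written by `self._save(output_dir, state_dict={})`
    under DeepSpeed ZeRO-3, so that only the checkpoint produced afterwards by
    `self.model.save_checkpoint(output_dir)` remains in the output directory."""
    for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
        file = os.path.join(output_dir, filename)
        if os.path.isfile(file):
            os.remove(file)

Because the removal is now done inline with plain os calls, the remove_dummy_checkpoint import from transformers.trainer_pt_utils is no longer needed, which is exactly what the first hunk drops.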