mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
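Fix PPO trainer save logic: remove the dummy checkpoint files with an inline os.remove loop instead of calling transformers' remove_dummy_checkpoint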
Former-commit-id: d3dccd0693ede18a99f04780f2fd6e3a89810405
This commit is contained in:
parent
58cbc9e0be
commit
027caabbb6
@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
|||||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||||
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||||
from transformers.trainer_pt_utils import remove_dummy_checkpoint
|
|
||||||
|
|
||||||
from trl import PPOTrainer
|
from trl import PPOTrainer
|
||||||
from trl.core import PPODecorators, logprobs_from_logits
|
from trl.core import PPODecorators, logprobs_from_logits
|
||||||
@@ -361,9 +360,13 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
|||||||
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
|
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
|
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
|
||||||
" zero_to_fp32.py to recover weights"
|
" use zero_to_fp32.py to recover weights"
|
||||||
)
|
)
|
||||||
self._save(output_dir, state_dict={})
|
self._save(output_dir, state_dict={})
|
||||||
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
|
for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: # remove dummy checkpoint
|
||||||
|
file = os.path.join(output_dir, filename)
|
||||||
|
if os.path.isfile(file):
|
||||||
|
os.remove(file)
|
||||||
|
|
||||||
self.model.save_checkpoint(output_dir) # wrapped model
|
self.model.save_checkpoint(output_dir) # wrapped model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user