Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-10-16 00:28:10 +08:00
fix ppo trainer save zero3 model

accelerator.get_state_dict(ds_model) should be called at all ranks.

Former-commit-id: 3a0f60f0aa072531e4ae5819ec00c8fa42aa0913
parent 8692796c9b
commit b0e5a76f4c
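Under ZeRO-3 the model parameters are sharded across ranks, so building a full
state dict is a collective operation: accelerator.get_state_dict() has every
rank contribute its shard. The old code gated the call behind
self.args.should_save, which is true only on the main process, so the other
ranks never entered the collective and the main rank blocked forever. A minimal
sketch of the corrected pattern, not LLaMA-Factory code (save_consolidated and
is_main_process are illustrative names):

    import torch

    def save_consolidated(accelerator, model, output_dir, is_main_process):
        # WRONG: gating the gather on the main process deadlocks, because
        # the other ranks never reach the collective:
        #     if is_main_process:
        #         state_dict = accelerator.get_state_dict(model)

        # RIGHT: every rank participates in the gather ...
        state_dict = accelerator.get_state_dict(model)
        # ... and only the main process writes the result to disk.
        if is_main_process:
            torch.save(state_dict, f"{output_dir}/pytorch_model.bin")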
@@ -123,9 +123,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.state = TrainerState()
         self.control = TrainerControl()
-        self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
-            self.accelerator.state, "deepspeed_plugin"
-        )
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
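The new detection mirrors how transformers' Trainer flags its backends: rather
than string-comparing distributed_type, it checks whether Accelerate attached
the corresponding plugin to its state. A hedged sketch of the same check in
isolation (the Accelerator setup here is illustrative):

    from accelerate import Accelerator

    accelerator = Accelerator()  # picks up any configured DeepSpeed/FSDP plugin

    # The plugin attributes are only set when that backend is configured,
    # so a None-check doubles as a feature flag:
    is_deepspeed_enabled = getattr(accelerator.state, "deepspeed_plugin", None) is not None
    is_fsdp_enabled = getattr(accelerator.state, "fsdp_plugin", None) is not None
    print(f"deepspeed={is_deepspeed_enabled} fsdp={is_fsdp_enabled}")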
@@ -466,18 +465,28 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         Subclass and override to inject custom behavior.
         """
-        if self.args.should_save:
+        if output_dir is None:
+            output_dir = self.args.output_dir
+
+        if self.is_fsdp_enabled or self.is_deepspeed_enabled:
             try:
-                self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
+                state_dict = self.accelerator.get_state_dict(self.model)  # must be called at all ranks
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
             except ValueError:
                 logger.warning(
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                     " use zero_to_fp32.py to recover weights"
                 )
-                self._save(output_dir, state_dict={})
-                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                if self.args.should_save:
+                    self._save(output_dir, state_dict={})
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 self.model.save_checkpoint(output_dir)
+        elif self.args.should_save:
+            self._save(output_dir)

-        if self.processor is not None:
+        if self.processor is not None and self.args.should_save:
             output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)
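When stage3_gather_16bit_weights_on_model_save is false the gather raises, and
the fallback instead writes a dummy state dict, deletes it again, and lets
DeepSpeed dump the sharded ZeRO checkpoint via model.save_checkpoint(output_dir).
The full fp32 weights can then be reassembled offline, either with the
zero_to_fp32.py script that DeepSpeed places in the checkpoint directory or with
its Python helper. A hedged sketch of the latter (the checkpoint path is
illustrative):

    import torch
    from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

    checkpoint_dir = "output/ppo"  # the output_dir that save_model() wrote to
    # Merge the per-rank ZeRO shards back into a single fp32 state dict ...
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir)
    # ... and save it in the usual single-file format.
    torch.save(state_dict, f"{checkpoint_dir}/pytorch_model.bin")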
@@ -10,12 +10,15 @@ from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available
 if TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer


 if is_jieba_available():
     import jieba  # type: ignore


 if is_nltk_available():
     from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu


 if is_rouge_available():
     from rouge_chinese import Rouge
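The is_*_available() guards keep jieba, nltk, and rouge-chinese optional: the
imports only run when the package is actually installed. A minimal sketch of
how such a guard is commonly implemented (the project's extras.packages may
differ; this implementation is an assumption):

    import importlib.util

    def is_jieba_available() -> bool:
        # True only if the optional dependency can be imported.
        return importlib.util.find_spec("jieba") is not None

    if is_jieba_available():
        import jieba  # type: ignore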