Fix PPO trainer saving of DeepSpeed ZeRO-3 models

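accelerator.get_state_dict(ds_model) must be called on all ranks, not only on the rank that writes the checkpoint, because gathering the sharded ZeRO-3 weights is a collective operation.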


Former-commit-id: 4489d73ac75c8dbc002fc16c854148994d432c3a
This commit is contained in:
hiyouga 2024-06-07 05:14:19 +08:00
parent f76d427332
commit 4f3c89a6eb
2 changed files with 22 additions and 10 deletions

View File

@ -123,9 +123,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.state = TrainerState() self.state = TrainerState()
self.control = TrainerControl() self.control = TrainerControl()
self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr( self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
self.accelerator.state, "deepspeed_plugin" self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
)
self.log_callback, self.save_callback = callbacks[0], callbacks[1] self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback) assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
@ -466,18 +465,28 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
Subclass and override to inject custom behavior. Subclass and override to inject custom behavior.
""" """
if self.args.should_save: if output_dir is None:
output_dir = self.args.output_dir
if self.is_fsdp_enabled or self.is_deepspeed_enabled:
try: try:
self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model)) state_dict = self.accelerator.get_state_dict(self.model) # must be called at all ranks
if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
except ValueError: except ValueError:
logger.warning( logger.warning(
" stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead," " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
" use zero_to_fp32.py to recover weights" " use zero_to_fp32.py to recover weights"
) )
if self.args.should_save:
self._save(output_dir, state_dict={}) self._save(output_dir, state_dict={})
remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]) # remove the dummy state_dict
remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
self.model.save_checkpoint(output_dir) self.model.save_checkpoint(output_dir)
if self.processor is not None: elif self.args.should_save:
self._save(output_dir)
if self.processor is not None and self.args.should_save:
output_dir = output_dir if output_dir is not None else self.args.output_dir output_dir = output_dir if output_dir is not None else self.args.output_dir
getattr(self.processor, "image_processor").save_pretrained(output_dir) getattr(self.processor, "image_processor").save_pretrained(output_dir)

View File

@ -10,12 +10,15 @@ from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_a
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils import PreTrainedTokenizer
if is_jieba_available(): if is_jieba_available():
import jieba # type: ignore import jieba # type: ignore
if is_nltk_available(): if is_nltk_available():
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
if is_rouge_available(): if is_rouge_available():
from rouge_chinese import Rouge from rouge_chinese import Rouge