diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 0ad1b8e8..2e1288e4 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -123,9 +123,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         self.state = TrainerState()
         self.control = TrainerControl()
-        self.is_deepspeed_enabled = self.accelerator.distributed_type == "DEEPSPEED" and hasattr(
-            self.accelerator.state, "deepspeed_plugin"
-        )
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
         self.log_callback, self.save_callback = callbacks[0], callbacks[1]
         assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, FixValueHeadModelCallback)
 
@@ -466,18 +465,28 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         Subclass and override to inject custom behavior.
         """
-        if self.args.should_save:
+        if output_dir is None:
+            output_dir = self.args.output_dir
+
+        if self.is_fsdp_enabled or self.is_deepspeed_enabled:
             try:
-                self._save(output_dir, state_dict=self.accelerator.get_state_dict(self.model))
+                state_dict = self.accelerator.get_state_dict(self.model)  # must be called at all ranks
+                if self.args.should_save:
+                    self._save(output_dir, state_dict=state_dict)
             except ValueError:
                 logger.warning(
                     " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                     " use zero_to_fp32.py to recover weights"
                 )
-                self._save(output_dir, state_dict={})
-                remove_dummy_checkpoint(True, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
+                if self.args.should_save:
+                    self._save(output_dir, state_dict={})
+                # remove the dummy state_dict
+                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                 self.model.save_checkpoint(output_dir)
-        if self.processor is not None:
-            output_dir = output_dir if output_dir is not None else self.args.output_dir
-            getattr(self.processor, "image_processor").save_pretrained(output_dir)
+        elif self.args.should_save:
+            self._save(output_dir)
+
+        if self.processor is not None and self.args.should_save:
+            output_dir = output_dir if output_dir is not None else self.args.output_dir
+            getattr(self.processor, "image_processor").save_pretrained(output_dir)
 
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index d1af4c17..b135fcfb 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -10,12 +10,15 @@ from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_a
 
 if TYPE_CHECKING:
     from transformers.tokenization_utils import PreTrainedTokenizer
 
+
 if is_jieba_available():
     import jieba  # type: ignore
 
+
 if is_nltk_available():
     from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
 
+
 if is_rouge_available():
     from rouge_chinese import Rouge
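
Reviewer note: the heart of the trainer.py change is that `accelerator.get_state_dict(self.model)`
is a collective operation under DeepSpeed ZeRO-3 and FSDP, so every rank has to execute it even
though only the process with `should_save` writes to disk; the old code gated the call behind
`self.args.should_save`, which can hang the non-saving ranks, hence the new comment "must be
called at all ranks". A minimal sketch of the pattern, assuming HF Accelerate is installed
(standalone illustration, not code from this PR; the Linear model is a hypothetical stand-in):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    model = accelerator.prepare(torch.nn.Linear(8, 8))  # stand-in for the wrapped policy model

    # Collective call: under ZeRO-3/FSDP each rank contributes its shard;
    # skipping it on some ranks would deadlock the gather.
    state_dict = accelerator.get_state_dict(model)

    # Only the main process touches the filesystem.
    if accelerator.is_main_process:
        torch.save(state_dict, "model.bin")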