From 468d0e7ed1db9228f5b74b90965332c097940007 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Wed, 29 May 2024 00:23:23 +0800
Subject: [PATCH] 10x generate in ppo w/ zero3

https://github.com/huggingface/trl/pull/1483

Former-commit-id: 65cd8bdbdbe1b19250ecd813aeb72c8e00ef2f9c
---
 src/llamafactory/train/ppo/trainer.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 985664b7..27353c72 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -13,6 +13,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
+from trl.models.utils import unwrap_model_for_generation
 
 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
@@ -322,10 +323,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         for k, v in batch.items():
             batch[k] = v[:, start_index:]
 
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        generate_output: torch.Tensor = unwrapped_model.generate(
-            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
-        )
+        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            generate_output: torch.Tensor = unwrapped_model.generate(
+                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
+            )
 
         if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
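
Note (not part of the patch): the speedup in the subject line comes from how generation interacts
with DeepSpeed ZeRO-3. Under ZeRO-3, sharded parameters are gathered layer by layer on every
forward pass, which repeats at each decoding step of an autoregressive generate() call. The
unwrap_model_for_generation context manager from trl (introduced in the linked PR 1483) gathers
the full parameters once for the whole call and re-shards them on exit. A minimal usage sketch,
assuming a DeepSpeed-wrapped model and an accelerate Accelerator; the helper name
generate_with_gathered_weights is hypothetical, not part of the patched trainer:

    import torch
    from trl.models.utils import unwrap_model_for_generation

    def generate_with_gathered_weights(model, accelerator, batch, generation_config):
        # Under ZeRO-3, this context manager gathers the sharded weights once
        # for the entire generate() call instead of once per decoding step.
        with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
            output: torch.Tensor = unwrapped_model.generate(
                generation_config=generation_config, **batch
            )
        return output  # parameters are re-sharded when the context exits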