10x faster generation in PPO with ZeRO-3

https://github.com/huggingface/trl/pull/1483

Former-commit-id: 65cd8bdbdbe1b19250ecd813aeb72c8e00ef2f9c
hiyouga 2024-05-29 00:23:23 +08:00
parent bfac965f9c
commit 468d0e7ed1

@@ -13,6 +13,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
+from trl.models.utils import unwrap_model_for_generation
 
 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
@@ -322,10 +323,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         for k, v in batch.items():
             batch[k] = v[:, start_index:]
 
-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        generate_output: torch.Tensor = unwrapped_model.generate(
-            generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
-        )
+        with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
+            generate_output: torch.Tensor = unwrapped_model.generate(
+                generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
+            )
 
         if self.model_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
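For context, here is a minimal standalone sketch of the pattern this commit adopts (the model name, prompt, and max_new_tokens below are illustrative, not from the commit). Under DeepSpeed ZeRO-3 the weights are sharded across ranks, so a plain accelerator.unwrap_model() leaves them sharded and every decoding step re-gathers each layer's parameters; trl's unwrap_model_for_generation context manager gathers the full parameters once for the entire generate() call, which is where the roughly 10x speedup comes from.

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.models.utils import unwrap_model_for_generation

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("gpt2"))

batch = tokenizer(["Hello, world"], return_tensors="pt").to(accelerator.device)

# Gathers ZeRO-3-sharded weights once for the whole generation loop,
# instead of re-gathering them layer by layer at every decoding step.
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    generate_output: torch.Tensor = unwrapped_model.generate(**batch, max_new_tokens=16)

print(tokenizer.decode(generate_output[0], skip_special_tokens=True))

When no ZeRO-3 sharding is active, the context manager simply unwraps the model, so the same code path works in single-GPU and DDP setups as well.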