mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
10x generate in ppo w/ zero3
https://github.com/huggingface/trl/pull/1483 Former-commit-id: 65cd8bdbdbe1b19250ecd813aeb72c8e00ef2f9c
This commit is contained in:
parent
bfac965f9c
commit
468d0e7ed1
@ -13,6 +13,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from trl import PPOConfig, PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
from trl.models.utils import unwrap_model_for_generation
|
||||
|
||||
from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
|
||||
from ...extras.logging import get_logger
|
||||
@ -322,7 +323,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
for k, v in batch.items():
|
||||
batch[k] = v[:, start_index:]
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
generate_output: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user