This commit is contained in:
hiyouga
2023-09-21 19:51:02 +08:00
parent ace3f85a72
commit 338b8664ed
11 changed files with 116 additions and 101 deletions

View File

@@ -2,13 +2,13 @@ import os
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import PPOTrainer
from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
@@ -47,7 +47,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.state = TrainerState()
self.control = TrainerControl()
def ppo_train(self, max_target_length: int) -> None:
def ppo_train(self) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
@@ -81,9 +81,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
pad_token_id=self.tokenizer.pad_token_id
))
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter()
@@ -100,7 +98,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
queries, responses = self.get_inputs(batch, length_sampler, generating_args)
queries, responses = self.get_inputs(batch, generating_args)
self.tokenizer.padding_side = "right" # change padding side
rewards = self.get_rewards(queries, responses, unwrapped_model)
@@ -156,13 +154,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
def get_inputs(
self,
batch: Dict[str, torch.Tensor],
length_sampler: Callable,
generating_args: Dict[str, Any]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
generating_args["max_new_tokens"] = length_sampler()
gen_kwargs = dict(
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),