Author: hiyouga
Date:   2023-09-21 19:51:02 +08:00
parent ace3f85a72
commit 338b8664ed
11 changed files with 116 additions and 101 deletions

View File

@@ -173,7 +173,7 @@ def load_model_and_tokenizer(
)
# Disable custom generate method (for Qwen)
if "GenerationMixin" not in str(model.generate.__func__):
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
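Context for the hunk above: some remote-code model classes (Qwen at the time) ship their own generate(), so the loader rebinds the stock GenerationMixin implementation. The added isinstance guard likely exists so the rebind is skipped for objects that are not PreTrainedModel subclasses, such as trl's AutoModelForCausalLMWithValueHead wrapper. A standalone sketch of the same check (restore_stock_generate is a hypothetical helper name, not part of this commit):

from types import MethodType

from transformers import PreTrainedModel

def restore_stock_generate(model) -> None:
    # Hypothetical helper mirroring the hunk above: rebind the stock generate()
    # when a model class overrides it, but only for PreTrainedModel subclasses.
    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)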

View File

@@ -213,7 +213,7 @@ def get_train_args(
else:
model_args.compute_dtype = torch.float32
-model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
+model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(

View File

@@ -2,13 +2,13 @@ import os
import math
import torch
from tqdm import tqdm
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from trl import PPOTrainer
-from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
+from trl.core import PPODecorators, logprobs_from_logits
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
@@ -47,7 +47,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.state = TrainerState()
self.control = TrainerControl()
-def ppo_train(self, max_target_length: int) -> None:
+def ppo_train(self) -> None:
r"""
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
"""
@@ -81,9 +81,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
pad_token_id=self.tokenizer.pad_token_id
))
-length_sampler = LengthSampler(max_target_length // 2, max_target_length)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
loss_meter = AverageMeter()
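Background for the deletion above: trl's LengthSampler draws an integer uniformly from [min_value, max_value) on every call, so each PPO step used to get a different response-length budget between max_target_length // 2 and max_target_length. A short illustration of the removed behaviour, with illustrative bounds:

from trl.core import LengthSampler  # the same import dropped in the hunk above

sampler = LengthSampler(256, 512)  # e.g. max_target_length = 512
max_new_tokens = sampler()         # uniform draw in [256, 512), resampled per step

After this commit, max_new_tokens is taken directly from the user-supplied generating arguments, so the generation budget stays fixed across steps.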
@@ -100,7 +98,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.model.eval()
# Get inputs
-queries, responses = self.get_inputs(batch, length_sampler, generating_args)
+queries, responses = self.get_inputs(batch, generating_args)
self.tokenizer.padding_side = "right" # change padding side
rewards = self.get_rewards(queries, responses, unwrapped_model)
@@ -156,13 +154,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
def get_inputs(
self,
batch: Dict[str, torch.Tensor],
-length_sampler: Callable,
generating_args: Dict[str, Any]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
generating_args["max_new_tokens"] = length_sampler()
gen_kwargs = dict(
generation_config=GenerationConfig(**generating_args),
logits_processor=get_logits_processor(),
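Taken together with the import and signature changes above, get_inputs() no longer mutates generating_args on each call; the generation config is built once from the user's settings. A condensed sketch of the resulting method head (body elided; GenerationConfig and get_logits_processor are the imports shown earlier):

def get_inputs(self, batch, generating_args):
    # generating_args already carries max_new_tokens; nothing overwrites it here.
    gen_kwargs = dict(
        generation_config=GenerationConfig(**generating_args),
        logits_processor=get_logits_processor(),
    )
    # ... generation and decoding proceed as before ...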

View File

@@ -79,7 +79,7 @@ def run_ppo(
# Training
if training_args.do_train:
-ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
+ppo_trainer.ppo_train()
ppo_trainer.save_model()
ppo_trainer.save_state() # must be called after save_model to have a folder
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
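At the call site the change is mechanical; a minimal usage sketch after this commit:

if training_args.do_train:
    ppo_trainer.ppo_train()    # no max_target_length argument anymore
    ppo_trainer.save_model()
    ppo_trainer.save_state()   # must follow save_model so the output folder exists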