Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix #944
@@ -173,7 +173,7 @@ def load_model_and_tokenizer(
     )

     # Disable custom generate method (for Qwen)
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)

     # Fix LM head (for ChatGLM2)
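
Note (not part of the commit): a minimal sketch of the pattern this hunk guards. Checkpoints that ship custom remote code (e.g. Qwen) can override generate, and the loader re-binds the stock PreTrainedModel.generate onto the instance; the added isinstance check presumably skips this re-binding for objects that are not PreTrainedModel subclasses, such as value-head wrappers used in RLHF. The helper name below is hypothetical.

from types import MethodType
from transformers import PreTrainedModel

def restore_stock_generate(model):
    # Re-bind the default generate only for genuine PreTrainedModel instances whose
    # generate was replaced by custom code (i.e. it no longer comes from GenerationMixin).
    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
        model.generate = MethodType(PreTrainedModel.generate, model)
    return model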
@@ -213,7 +213,7 @@ def get_train_args(
     else:
         model_args.compute_dtype = torch.float32

-    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
+    model_args.model_max_length = data_args.cutoff_len

     # Log on each process the small summary:
     logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
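
Note (not part of the commit): a brief hedged illustration of the length-budget change. model_max_length is now derived from a single cutoff_len instead of the sum of separate source and target budgets; the numbers below are placeholders, not defaults from the repository.

# Before: two separate budgets for prompt and response tokens.
max_source_length, max_target_length = 512, 512
model_max_length = max_source_length + max_target_length   # 1024

# After: one combined budget for the whole sequence.
cutoff_len = 1024
model_max_length = cutoff_len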
@@ -2,13 +2,13 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

 from trl import PPOTrainer
-from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
+from trl.core import PPODecorators, logprobs_from_logits

 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
@@ -47,7 +47,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.state = TrainerState()
         self.control = TrainerControl()

-    def ppo_train(self, max_target_length: int) -> None:
+    def ppo_train(self) -> None:
         r"""
         Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
         """
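
Note (not part of the commit): for orientation, a heavily simplified outline of the loop that ppo_train drives, written as a standalone function. The method names mirror this file and trl's PPOTrainer.step; the arguments and loop bounds are assumptions, not repository code.

from typing import Any, Dict, Iterator

def ppo_outline(trainer, dataiter: Iterator, generating_args: Dict[str, Any],
                unwrapped_model, total_steps: int) -> None:
    # One PPO iteration: roll out responses, score them, then run the PPO update.
    for _ in range(total_steps):
        batch = next(dataiter)
        trainer.model.eval()
        queries, responses = trainer.get_inputs(batch, generating_args)      # rollout
        rewards = trainer.get_rewards(queries, responses, unwrapped_model)   # reward scoring
        trainer.model.train()
        trainer.step(queries, responses, rewards)                            # trl PPO update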
@@ -81,9 +81,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             pad_token_id=self.tokenizer.pad_token_id
         ))

-        length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-
         dataiter = iter(self.dataloader)
         steps_trained = 0
         loss_meter = AverageMeter()
@@ -100,7 +98,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.model.eval()

             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
+            queries, responses = self.get_inputs(batch, generating_args)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)

@@ -156,13 +154,11 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Callable,
         generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        generating_args["max_new_tokens"] = length_sampler()
         gen_kwargs = dict(
             generation_config=GenerationConfig(**generating_args),
             logits_processor=get_logits_processor(),
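
Note (not part of the commit): a hedged sketch of what was removed versus what remains. Before, each batch drew a random response length through trl's LengthSampler; after, max_new_tokens comes straight from the generation arguments when the GenerationConfig is built. The literal values below are illustrative.

from transformers import GenerationConfig
from trl.core import LengthSampler

# Before: max_new_tokens was resampled per batch in [max_target_length // 2, max_target_length).
length_sampler = LengthSampler(256, 512)     # 512 stands in for max_target_length
sampled_max_new_tokens = length_sampler()

# After: the value configured in generating_args is used directly.
generating_args = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}
generation_config = GenerationConfig(**generating_args)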
@@ -79,7 +79,7 @@ def run_ppo(

     # Training
     if training_args.do_train:
-        ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
+        ppo_trainer.ppo_train()
         ppo_trainer.save_model()
         ppo_trainer.save_state() # must be called after save_model to have a folder
         if ppo_trainer.is_world_process_zero() and model_args.plot_loss: