diff --git a/src/llmtuner/tuner/core/trainer.py b/src/llmtuner/tuner/core/trainer.py
index 3d6a2d4b..9541805f 100644
--- a/src/llmtuner/tuner/core/trainer.py
+++ b/src/llmtuner/tuner/core/trainer.py
@@ -47,7 +47,6 @@ class PeftTrainer(Seq2SeqTrainer):
 
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = unwrap_model(self.model)
-
         if isinstance(model, PreTrainedModelWrapper):
             # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
             model_state_dict = state_dict or model.state_dict()
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index d3f79f05..d1f47850 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -2,10 +2,9 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 
 from transformers import TrainerState, TrainerControl
-from transformers.modeling_utils import PreTrainedModel
 
 from trl import PPOTrainer
 from trl.core import LengthSampler
@@ -18,6 +17,7 @@ from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
+    from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.extras.callbacks import LogCallback
     from llmtuner.hparams import FinetuningArguments
 
@@ -43,7 +43,6 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.log_callback = callbacks[0]
         self.state = TrainerState()
         self.control = TrainerControl()
-        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
         self._remove_log()
 
     def ppo_train(self, max_target_length: int) -> None:
@@ -83,7 +82,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             "logits_processor": get_logits_processor()
         }
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
-        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
 
         dataiter = iter(self.dataloader)
         steps_trained = 0
@@ -95,38 +94,22 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
            batch = next(dataiter)
            steps_trained += 1
 
+            # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
+            unwrapped_model.eval()
 
-            # Get responses
-            query_tensors = batch["input_ids"]
-            response_tensors = self.generate(
-                batch, length_sampler, return_prompt=False, **gen_kwargs
-            ).detach().cpu() # move to cpu
+            # Get inputs
+            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            rewards = self.get_rewards(queries, responses, unwrapped_model)
 
-            queries, responses = [], []
-            for i in range(len(query_tensors)):
-                query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
-                response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-                queries.append(query_tensors[i, query_length:]) # remove padding from left
-                responses.append(response_tensors[i, :response_length]) # remove padding from right
-
-            # Compute rewards
-            replace_model(unwrapped_model, target="reward")
-            with torch.no_grad():
-                _, _, values = self.model(
-                    **self.prepare_model_inputs(queries, responses),
-                    output_hidden_states=True,
-                    return_dict=True
-                )
-            rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
-            replace_model(unwrapped_model, target="default")
-
-            # Run PPO step
+            # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            stats = self.step(queries, responses, rewards)
+            unwrapped_model.train()
 
+            # Run PPO step
+            stats = self.step(queries, responses, rewards)
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
 
@@ -155,37 +138,57 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 steps_trained = 0
 
     @torch.no_grad()
-    def generate(
+    def get_inputs(
         self,
         inputs: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
-        return_prompt: Optional[bool] = True,
         **generation_kwargs
-    ) -> torch.Tensor:
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
-
-        Subclass and override to inject custom behavior.
         """
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
-
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()
 
-        unwrapped_model = self.accelerator.unwrap_model(self.model)
-
-        response = unwrapped_model.generate(**inputs, **generation_kwargs)
+        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
+        response: torch.Tensor = unwrapped_model.generate(**inputs, **generation_kwargs)
+        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
 
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
         if unwrapped_model.pretrained_model.generation_config._from_model_config:
             unwrapped_model.pretrained_model.generation_config._from_model_config = False
 
-        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
+        queries, responses = [], []
+        query, response = inputs["input_ids"], response[:, inputs["input_ids"].size(-1):].detach().cpu()
+        for i in range(len(query)):
+            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
+            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            queries.append(query[i, query_length:]) # remove padding from left
+            responses.append(response[i, :response_length]) # remove padding from right
 
-        if not return_prompt and not self.is_encoder_decoder:
-            return response[:, inputs["input_ids"].size(1):]
-        return response
+        return queries, responses
+
+    @torch.no_grad()
+    def get_rewards(
+        self,
+        queries: List[torch.Tensor],
+        responses: List[torch.Tensor],
+        unwrapped_model: "AutoModelForCausalLMWithValueHead"
+    ) -> List[torch.Tensor]:
+        r"""
+        Computes scores using given reward model.
+        """
+        replace_model(unwrapped_model, target="reward")
+        _, _, values = self.model(
+            **self.prepare_model_inputs(queries, responses),
+            output_hidden_states=True,
+            return_dict=True
+        )
+        rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
+        replace_model(unwrapped_model, target="default")
+        return rewards
 
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
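Note (illustration, not part of the patch): the padding arithmetic in `get_inputs` assumes queries come out of the collator left-padded while generated responses are right-padded, so `.nonzero()[0]` on the pad mask gives the first real query token and `.nonzero()[-1] + 1` gives the end of the real response. A minimal standalone sketch of that logic, assuming `pad_token_id == 0` purely for the example:

    import torch

    pad = 0  # stand-in for tokenizer.pad_token_id
    query = torch.tensor([pad, pad, 5, 6, 7])  # left-padded prompt
    response = torch.tensor([8, 9, pad, pad])  # right-padded generation

    # nonzero() on the boolean mask lists the indices of non-pad tokens
    query_length = (query != pad).nonzero()[0]             # first non-pad index -> 2
    response_length = (response != pad).nonzero()[-1] + 1  # last non-pad index + 1 -> 2

    print(query[query_length:])        # tensor([5, 6, 7]): left padding stripped
    print(response[:response_length])  # tensor([8, 9]): right padding stripped

Relatedly, `get_rewards` reuses the policy backbone for scoring: it swaps the reward weights in with `replace_model(unwrapped_model, target="reward")`, reads the value head's output at the final position (cast to fp32 on CPU), and restores the default weights with `replace_model(unwrapped_model, target="default")`.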