From f865d0bd51336114c4f3f055caba5615a5f75efd Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sat, 9 Sep 2023 17:04:45 +0800
Subject: [PATCH] fix lora target

Former-commit-id: a51b7c98acc599de5ed2eaeeebe7b184105722c5
---
 src/llmtuner/chat/stream_chat.py   | 34 ++++++++++++-----------
 src/llmtuner/dsets/preprocess.py   |  9 ++++++-
 src/llmtuner/extras/template.py    |  4 +--
 src/llmtuner/tuner/core/adapter.py |  2 +-
 src/llmtuner/tuner/core/utils.py   |  2 +-
 src/llmtuner/tuner/ppo/trainer.py  | 43 +++++++++++++++++-------------
 src/llmtuner/tuner/rm/trainer.py   | 12 ++++++---
 7 files changed, 63 insertions(+), 43 deletions(-)

diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index 41f1f416..c6dfe30e 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -1,7 +1,7 @@
 import torch
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import GenerationConfig, TextIteratorStreamer
 
 from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
@@ -40,26 +40,30 @@ class ChatModel:
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
 
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs.update(dict(
-            input_ids=input_ids,
-            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
-            temperature=temperature or gen_kwargs["temperature"],
-            top_p=top_p or gen_kwargs["top_p"],
-            top_k=top_k or gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+            temperature=temperature or generating_args["temperature"],
+            top_p=top_p or generating_args["top_p"],
+            top_k=top_k or generating_args["top_k"],
+            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor()
+            pad_token_id=self.tokenizer.pad_token_id
         ))
 
         if max_length:
-            gen_kwargs.pop("max_new_tokens", None)
-            gen_kwargs["max_length"] = max_length
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length
 
         if max_new_tokens:
-            gen_kwargs.pop("max_length", None)
-            gen_kwargs["max_new_tokens"] = max_new_tokens
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens
+
+        gen_kwargs = dict(
+            inputs=input_ids,
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor()
+        )
 
         return gen_kwargs, prompt_length
 
diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/dsets/preprocess.py
index 6c86a166..393366e6 100644
--- a/src/llmtuner/dsets/preprocess.py
+++ b/src/llmtuner/dsets/preprocess.py
@@ -74,7 +74,7 @@ def preprocess_dataset(
             if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                 break
 
-            if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
+            if turn_idx != 0 and template.efficient_eos:
                 source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
             else:
                 source_mask = [IGNORE_INDEX] * len(source_ids)
@@ -104,6 +104,9 @@
             if len(target_ids) > data_args.max_target_length:
                 target_ids = target_ids[:data_args.max_target_length]
 
+            if template.efficient_eos:
+                target_ids += [tokenizer.eos_token_id]
+
             model_inputs["input_ids"].append(source_ids)
             model_inputs["attention_mask"].append([1] * len(source_ids))
             model_inputs["labels"].append(target_ids)
@@ -124,6 +127,10 @@
             if len(rejected_ids) > data_args.max_target_length:
                 rejected_ids = rejected_ids[:data_args.max_target_length]
 
+            if template.efficient_eos:
+                chosen_ids += [tokenizer.eos_token_id]
+                rejected_ids += [tokenizer.eos_token_id]
+
             model_inputs["prompt_ids"].append(prompt_ids)
             model_inputs["chosen_ids"].append(chosen_ids)
             model_inputs["rejected_ids"].append(rejected_ids)
diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py
index 167ef222..aa7511af 100644
--- a/src/llmtuner/extras/template.py
+++ b/src/llmtuner/extras/template.py
@@ -77,13 +77,13 @@ class Template:
     ) -> Tuple[List[int], List[int]]:
         if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
             bos_ids = [tokenizer.bos_token_id]
-        else: # baichuan, qwen and gpt2 models has no bos token
+        else: # baichuan, qwen and gpt2 models have no bos token
            bos_ids = []
 
         if tokenizer.eos_token_id is None:
             raise ValueError("EOS token is required.")
 
-        if self.efficient_eos: # used in baichuan, qwen and gpt2 models
+        if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
             eos_ids = []
         else:
             eos_ids = [tokenizer.eos_token_id]
diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/tuner/core/adapter.py
index 1635a2b7..0324bc74 100644
--- a/src/llmtuner/tuner/core/adapter.py
+++ b/src/llmtuner/tuner/core/adapter.py
@@ -82,7 +82,7 @@ def init_adapter(
             model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
 
         if is_trainable and latest_checkpoint is None: # create new lora weights while training
-            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target == "all":
+            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model, model_args.quantization_bit)
             else:
                 target_modules = finetuning_args.lora_target
diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/tuner/core/utils.py
index b40ce893..74ff075f 100644
--- a/src/llmtuner/tuner/core/utils.py
+++ b/src/llmtuner/tuner/core/utils.py
@@ -20,7 +20,7 @@ def find_all_linear_modules(
 
     module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, linear_cls):
+        if output_layer_name not in name and isinstance(module, linear_cls):
            module_names.add(name.split(".")[-1])
 
     if output_layer_name in module_names:
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index c21ec2ff..981e6d41 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -2,9 +2,9 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
-from transformers import TrainerState, TrainerControl
+from transformers import GenerationConfig, TrainerState, TrainerControl
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 
@@ -78,10 +78,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
 
         # Keyword arguments for `model.generate`
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
-        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
-        gen_kwargs["logits_processor"] = get_logits_processor()
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            pad_token_id=self.tokenizer.pad_token_id
+        ))
 
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -103,7 +104,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             self.model.eval()
 
             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)
 
@@ -152,32 +153,36 @@
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Optional[Callable] = None,
-        **generation_kwargs
+        length_sampler: Callable,
+        generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        if length_sampler is not None:
-            generation_kwargs["max_new_tokens"] = length_sampler()
+        generating_args["max_new_tokens"] = length_sampler()
+        gen_kwargs = dict(
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor(),
+            **batch
+        )
 
+        input_ids = batch["input_ids"]
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
-
-        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
-        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
-        if unwrapped_model.pretrained_model.generation_config._from_model_config:
-            unwrapped_model.pretrained_model.generation_config._from_model_config = False
+        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
+        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
 
         queries, responses = [], []
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+
             if len(response_index) == 0:
                 response_length = 1 # allow empty response
+            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+                response_length = response_index[-1] + 2 # save the EOS token
             else:
                 response_length = response_index[-1] + 1
+
             queries.append(query[i, query_length:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
 
@@ -204,7 +209,7 @@
 
         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1]
+            end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
 
         replace_model(unwrapped_model, target="default")
diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py
index 854f9792..23b33539 100644
--- a/src/llmtuner/tuner/rm/trainer.py
+++ b/src/llmtuner/tuner/rm/trainer.py
@@ -69,13 +69,17 @@ class PairwisePeftTrainer(PeftTrainer):
             assert div_index > 0
             chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
             rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            chosen_scores.append(chosen_trunc_rewards[-1]) # use the end score for inference
-            rejected_scores.append(rejected_trunc_rewards[-1])
+            if return_outputs: # use the score on the EOS token for inference
+                chosen_scores.append(chosen_rewards[i, chosen_length-1])
+                rejected_scores.append(rejected_rewards[i, rejected_length-1])
             loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
 
         loss = loss / batch_size
-        chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
-        return (loss, [loss, chosen_scores, rejected_scores]) if return_outputs else loss
+        if return_outputs:
+            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+            return loss, [loss, chosen_scores, rejected_scores]
+
+        return loss
 
     def save_predictions(
         self,