fix lora target

Former-commit-id: a51b7c98acc599de5ed2eaeeebe7b184105722c5
hiyouga 2023-09-09 17:04:45 +08:00
parent c818a7ff60
commit f865d0bd51
7 changed files with 63 additions and 43 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import GenerationConfig, TextIteratorStreamer
 from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
@@ -40,26 +40,30 @@ class ChatModel:
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs.update(dict(
-            input_ids=input_ids,
-            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
-            temperature=temperature or gen_kwargs["temperature"],
-            top_p=top_p or gen_kwargs["top_p"],
-            top_k=top_k or gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+            temperature=temperature or generating_args["temperature"],
+            top_p=top_p or generating_args["top_p"],
+            top_k=top_k or generating_args["top_k"],
+            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor()
+            pad_token_id=self.tokenizer.pad_token_id
         ))
         if max_length:
-            gen_kwargs.pop("max_new_tokens", None)
-            gen_kwargs["max_length"] = max_length
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length
         if max_new_tokens:
-            gen_kwargs.pop("max_length", None)
-            gen_kwargs["max_new_tokens"] = max_new_tokens
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens
+        gen_kwargs = dict(
+            inputs=input_ids,
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor()
+        )
         return gen_kwargs, prompt_length
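
The hunk above stops passing loose sampling keywords to generate and instead freezes them into a transformers GenerationConfig. A minimal sketch of the same pattern outside this repo (the model name and parameter values are placeholders, not taken from the commit):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# collect decoding parameters in a plain dict, then freeze them into a GenerationConfig
generating_args = dict(
    do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=32,
    eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id
)
inputs = tokenizer("Hello", return_tensors="pt")
output = model.generate(inputs=inputs["input_ids"], generation_config=GenerationConfig(**generating_args))
print(tokenizer.decode(output[0], skip_special_tokens=True))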

View File

@@ -74,7 +74,7 @@ def preprocess_dataset(
             if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                 break
-            if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
+            if turn_idx != 0 and template.efficient_eos:
                 source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
             else:
                 source_mask = [IGNORE_INDEX] * len(source_ids)
@@ -104,6 +104,9 @@ def preprocess_dataset(
         if len(target_ids) > data_args.max_target_length:
             target_ids = target_ids[:data_args.max_target_length]
+        if template.efficient_eos:
+            target_ids += [tokenizer.eos_token_id]
         model_inputs["input_ids"].append(source_ids)
         model_inputs["attention_mask"].append([1] * len(source_ids))
         model_inputs["labels"].append(target_ids)
@@ -124,6 +127,10 @@ def preprocess_dataset(
         if len(rejected_ids) > data_args.max_target_length:
             rejected_ids = rejected_ids[:data_args.max_target_length]
+        if template.efficient_eos:
+            chosen_ids += [tokenizer.eos_token_id]
+            rejected_ids += [tokenizer.eos_token_id]
         model_inputs["prompt_ids"].append(prompt_ids)
         model_inputs["chosen_ids"].append(chosen_ids)
         model_inputs["rejected_ids"].append(rejected_ids)
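
The two hunks above append the EOS token to target_ids (and to chosen_ids/rejected_ids) when the template uses efficient_eos, so the label sequence still ends with EOS even though the template itself emits none. A rough illustration with made-up token ids (the ids and IGNORE_INDEX usage below are illustrative, not copied from the repo):

IGNORE_INDEX = -100      # label value ignored by the loss
eos_token_id = 2
efficient_eos = True

source_ids = [5, 6, 7]   # prompt tokens
target_ids = [8, 9]      # response tokens; no EOS emitted by the template
if efficient_eos:
    target_ids = target_ids + [eos_token_id]

input_ids = source_ids + target_ids
labels = [IGNORE_INDEX] * len(source_ids) + target_ids
assert len(input_ids) == len(labels)   # [5,6,7,8,9,2] vs [-100,-100,-100,8,9,2]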

View File

@@ -77,13 +77,13 @@ class Template:
     ) -> Tuple[List[int], List[int]]:
         if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
             bos_ids = [tokenizer.bos_token_id]
-        else: # baichuan, qwen and gpt2 models has no bos token
+        else: # baichuan, qwen and gpt2 models have no bos token
             bos_ids = []
         if tokenizer.eos_token_id is None:
             raise ValueError("EOS token is required.")
-        if self.efficient_eos: # used in baichuan, qwen and gpt2 models
+        if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
             eos_ids = []
         else:
             eos_ids = [tokenizer.eos_token_id]

View File

@@ -82,7 +82,7 @@ def init_adapter(
             model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
         if is_trainable and latest_checkpoint is None: # create new lora weights while training
-            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target == "all":
+            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model, model_args.quantization_bit)
             else:
                 target_modules = finetuning_args.lora_target
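
This one-token change is the fix the commit title refers to: lora_target is parsed into a list of module names, so the old comparison checked a list against the string "all" and never matched, meaning the "all" shortcut never triggered. A hedged sketch of how the resolved names would typically reach a PEFT LoraConfig (model, model_args, find_all_linear_modules and the LoRA hyperparameters below are placeholders or repo helpers assumed to exist, not values from this commit):

from peft import LoraConfig, get_peft_model

lora_target = ["all"]   # what --lora_target all is parsed into: a one-element list
if len(lora_target) == 1 and lora_target[0] == "all":
    target_modules = find_all_linear_modules(model, model_args.quantization_bit)
else:
    target_modules = lora_target   # e.g. ["q_proj", "v_proj"]

lora_config = LoraConfig(
    task_type="CAUSAL_LM", r=8, lora_alpha=16, lora_dropout=0.05,  # placeholder hyperparameters
    target_modules=target_modules
)
model = get_peft_model(model, lora_config)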

View File

@@ -20,7 +20,7 @@ def find_all_linear_modules(
     module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, linear_cls):
+        if output_layer_name not in name and isinstance(module, linear_cls):
             module_names.add(name.split(".")[-1])
     if output_layer_name in module_names:
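
The updated condition excludes any module whose path contains output_layer_name while collecting linear layers, so the output projection (typically lm_head) is not gathered as a LoRA target in the first place. A self-contained toy sketch of the same scan (the module names are invented for illustration):

import torch.nn as nn

model = nn.ModuleDict({
    "q_proj": nn.Linear(8, 8),
    "v_proj": nn.Linear(8, 8),
    "lm_head": nn.Linear(8, 100),
})

output_layer_name = "lm_head"
module_names = set()
for name, module in model.named_modules():
    # skip the output layer so it is never returned as a LoRA target
    if output_layer_name not in name and isinstance(module, nn.Linear):
        module_names.add(name.split(".")[-1])

print(module_names)   # {"q_proj", "v_proj"}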

View File

@@ -2,9 +2,9 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
-from transformers import TrainerState, TrainerControl
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
+from transformers import GenerationConfig, TrainerState, TrainerControl
 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
@@ -78,10 +78,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
         # Keyword arguments for `model.generate`
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
-        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
-        gen_kwargs["logits_processor"] = get_logits_processor()
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            pad_token_id=self.tokenizer.pad_token_id
+        ))
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -103,7 +104,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             self.model.eval()
             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)
@@ -152,32 +153,36 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Optional[Callable] = None,
-        **generation_kwargs
+        length_sampler: Callable,
+        generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        if length_sampler is not None:
-            generation_kwargs["max_new_tokens"] = length_sampler()
+        generating_args["max_new_tokens"] = length_sampler()
+        gen_kwargs = dict(
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor(),
+            **batch
+        )
+        input_ids = batch["input_ids"]
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
-        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
-        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
-        if unwrapped_model.pretrained_model.generation_config._from_model_config:
-            unwrapped_model.pretrained_model.generation_config._from_model_config = False
+        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
+        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
         queries, responses = [], []
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
             if len(response_index) == 0:
                 response_length = 1 # allow empty response
+            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+                response_length = response_index[-1] + 2 # save the EOS token
             else:
                 response_length = response_index[-1] + 1
             queries.append(query[i, query_length:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
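
The new elif branch matters when the tokenizer reuses EOS as the padding token: the trailing EOS then looks like padding, so trimming at the last non-pad index would drop it, and keeping one extra position preserves it. A tiny illustration with made-up ids (the values are invented, not from the repo):

import torch

pad_token_id = eos_token_id = 2
response = torch.tensor([4, 5, 6, 2, 2, 2])   # real tokens, EOS, then right padding

response_index = (response != pad_token_id).nonzero()
last_non_pad = response_index[-1].item()      # index 2, the token just before EOS
print(response[: last_non_pad + 1])           # tensor([4, 5, 6])    -- EOS lost
print(response[: last_non_pad + 2])           # tensor([4, 5, 6, 2]) -- EOS kept
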
@@ -204,7 +209,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1]
+            end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
            rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
         replace_model(unwrapped_model, target="default")
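
Here the value head yields one scalar per position, and the added comment documents that the reward is read at the last attended position (the EOS token) rather than at a fixed index, so right padding does not shift it. A small illustration with dummy tensors (values invented):

import torch

values = torch.tensor([[0.1, 0.4, 0.9, 0.0, 0.0]])   # per-token scores from the value head
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])      # two padding positions on the right

end_index = attention_mask[0].nonzero()[-1]           # index 2: the last real (EOS) token
reward = values[0, end_index].float()                 # tensor([0.9000])
print(reward)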

View File

@@ -69,13 +69,17 @@ class PairwisePeftTrainer(PeftTrainer):
             assert div_index > 0
             chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
             rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            chosen_scores.append(chosen_trunc_rewards[-1]) # use the end score for inference
-            rejected_scores.append(rejected_trunc_rewards[-1])
+            if return_outputs: # use the score on the EOS token for inference
+                chosen_scores.append(chosen_rewards[i, chosen_length-1])
+                rejected_scores.append(rejected_rewards[i, rejected_length-1])
             loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
         loss = loss / batch_size
-        chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
-        return (loss, [loss, chosen_scores, rejected_scores]) if return_outputs else loss
+        if return_outputs:
+            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+            return loss, [loss, chosen_scores, rejected_scores]
+        return loss

     def save_predictions(
         self,
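
The reworked block keeps the pairwise ranking loss over the positions where the chosen and rejected sequences diverge, and only stacks the per-example EOS scores when return_outputs is set. A minimal sketch of the loss term for a single pair, with dummy reward tensors (the values and shapes are invented):

import torch

# per-token rewards for one example, from the divergence index to the end
chosen_trunc_rewards = torch.tensor([0.3, 0.8, 1.2])
rejected_trunc_rewards = torch.tensor([0.1, 0.2, 0.4])

# -log sigmoid(r_chosen - r_rejected), averaged over the compared positions;
# smaller when the chosen rewards consistently exceed the rejected ones
loss = -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
print(loss)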