Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)

Commit f865d0bd51 (parent c818a7ff60): fix lora target
Former-commit-id: a51b7c98acc599de5ed2eaeeebe7b184105722c5
@@ -1,7 +1,7 @@
 import torch
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import GenerationConfig, TextIteratorStreamer

 from llmtuner.extras.misc import dispatch_model, get_logits_processor
 from llmtuner.extras.template import get_template_and_fix_tokenizer
@@ -40,26 +40,30 @@ class ChatModel:
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)

-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs.update(dict(
-            input_ids=input_ids,
-            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
-            temperature=temperature or gen_kwargs["temperature"],
-            top_p=top_p or gen_kwargs["top_p"],
-            top_k=top_k or gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+            temperature=temperature or generating_args["temperature"],
+            top_p=top_p or generating_args["top_p"],
+            top_k=top_k or generating_args["top_k"],
+            repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
             eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
-            pad_token_id=self.tokenizer.pad_token_id,
-            logits_processor=get_logits_processor()
+            pad_token_id=self.tokenizer.pad_token_id
         ))

         if max_length:
-            gen_kwargs.pop("max_new_tokens", None)
-            gen_kwargs["max_length"] = max_length
+            generating_args.pop("max_new_tokens", None)
+            generating_args["max_length"] = max_length

         if max_new_tokens:
-            gen_kwargs.pop("max_length", None)
-            gen_kwargs["max_new_tokens"] = max_new_tokens
+            generating_args.pop("max_length", None)
+            generating_args["max_new_tokens"] = max_new_tokens

+        gen_kwargs = dict(
+            inputs=input_ids,
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor()
+        )
+
         return gen_kwargs, prompt_length
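For context, a minimal sketch (not part of the commit) of how a caller would consume the gen_kwargs returned above after this change; the "gpt2" checkpoint is only a stand-in for whatever model ChatModel actually wraps:

# Sketch: sampling options now travel inside a GenerationConfig instead of loose kwargs.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, world", return_tensors="pt").input_ids
generating_args = dict(
    do_sample=True, temperature=0.7, top_p=0.9, max_new_tokens=16,
    eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id
)
gen_kwargs = dict(inputs=input_ids, generation_config=GenerationConfig(**generating_args))

with torch.no_grad():
    output = model.generate(**gen_kwargs)                  # same call shape ChatModel uses
print(tokenizer.decode(output[0], skip_special_tokens=True))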
@@ -74,7 +74,7 @@ def preprocess_dataset(
                 if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break

-                if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
+                if turn_idx != 0 and template.efficient_eos:
                     source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
                 else:
                     source_mask = [IGNORE_INDEX] * len(source_ids)
@@ -104,6 +104,9 @@ def preprocess_dataset(
             if len(target_ids) > data_args.max_target_length:
                 target_ids = target_ids[:data_args.max_target_length]

+            if template.efficient_eos:
+                target_ids += [tokenizer.eos_token_id]
+
             model_inputs["input_ids"].append(source_ids)
             model_inputs["attention_mask"].append([1] * len(source_ids))
             model_inputs["labels"].append(target_ids)
@@ -124,6 +127,10 @@ def preprocess_dataset(
             if len(rejected_ids) > data_args.max_target_length:
                 rejected_ids = rejected_ids[:data_args.max_target_length]

+            if template.efficient_eos:
+                chosen_ids += [tokenizer.eos_token_id]
+                rejected_ids += [tokenizer.eos_token_id]
+
             model_inputs["prompt_ids"].append(prompt_ids)
             model_inputs["chosen_ids"].append(chosen_ids)
             model_inputs["rejected_ids"].append(rejected_ids)
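The two hunks above add the same rule to both preprocessing paths: when the template sets efficient_eos, the EOS token is appended to the target/chosen/rejected ids here rather than being emitted by the template. A tiny illustration with invented token ids:

eos_token_id = 2               # made-up EOS id
efficient_eos = True           # template flag

chosen_ids = [21, 22]
rejected_ids = [31]

if efficient_eos:
    chosen_ids += [eos_token_id]
    rejected_ids += [eos_token_id]

print(chosen_ids, rejected_ids)   # [21, 22, 2] [31, 2]  -> both sequences now end with EOS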
@@ -77,13 +77,13 @@ class Template:
     ) -> Tuple[List[int], List[int]]:
         if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
             bos_ids = [tokenizer.bos_token_id]
-        else: # baichuan, qwen and gpt2 models has no bos token
+        else: # baichuan, qwen and gpt2 models have no bos token
             bos_ids = []

         if tokenizer.eos_token_id is None:
             raise ValueError("EOS token is required.")

-        if self.efficient_eos: # used in baichuan, qwen and gpt2 models
+        if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
             eos_ids = []
         else:
             eos_ids = [tokenizer.eos_token_id]
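A stand-alone sketch of the branch logic in this method, using a stub object instead of a real tokenizer (all attribute values are invented):

class StubTokenizer:                     # stub, not a real transformers class
    bos_token_id = None                  # e.g. tokenizers that expose no usable BOS token
    eos_token_id = 2
    add_bos_token = False

tokenizer = StubTokenizer()
efficient_eos = True                     # template flag, see the preprocessing hunks above

if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
    bos_ids = [tokenizer.bos_token_id]
else:
    bos_ids = []

if tokenizer.eos_token_id is None:
    raise ValueError("EOS token is required.")

eos_ids = [] if efficient_eos else [tokenizer.eos_token_id]
print(bos_ids, eos_ids)                  # [] []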
@@ -82,7 +82,7 @@ def init_adapter(
             model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)

         if is_trainable and latest_checkpoint is None: # create new lora weights while training
-            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target == "all":
+            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                 target_modules = find_all_linear_modules(model, model_args.quantization_bit)
             else:
                 target_modules = finetuning_args.lora_target
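This hunk is the fix the commit message refers to: lora_target is parsed into a list of module names, so comparing the whole list to the string "all" could never be true. A two-line demonstration:

lora_target = ["all"]                                        # how the argument actually arrives

print(len(lora_target) == 1 and lora_target == "all")        # False: a list never equals a string
print(len(lora_target) == 1 and lora_target[0] == "all")     # True: compare the single element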
@@ -20,7 +20,7 @@ def find_all_linear_modules(

     module_names = set()
     for name, module in model.named_modules():
-        if isinstance(module, linear_cls):
+        if output_layer_name not in name and isinstance(module, linear_cls):
             module_names.add(name.split(".")[-1])

     if output_layer_name in module_names:
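A rough, simplified sketch of what the updated helper does after this change (the real function also selects the bitsandbytes Linear classes when a quantization bit is set; that part is omitted here):

import torch

def find_all_linear_modules_sketch(model: torch.nn.Module, output_layer_name: str = "lm_head"):
    """Collect leaf names of Linear modules, skipping anything under the output layer."""
    module_names = set()
    for name, module in model.named_modules():
        # the fix: filter on the module path so the output head is never targeted by LoRA
        if output_layer_name not in name and isinstance(module, torch.nn.Linear):
            module_names.add(name.split(".")[-1])
    return list(module_names)

class TinyModel(torch.nn.Module):        # toy model for demonstration
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(4, 4)
        self.lm_head = torch.nn.Linear(4, 8)

print(find_all_linear_modules_sketch(TinyModel()))   # ['q_proj'], lm_head is excluded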
@@ -2,9 +2,9 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

-from transformers import TrainerState, TrainerControl
+from transformers import GenerationConfig, TrainerState, TrainerControl

 from trl import PPOTrainer
 from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
@@ -78,10 +78,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")

         # Keyword arguments for `model.generate`
-        gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
-        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
-        gen_kwargs["logits_processor"] = get_logits_processor()
+        generating_args = self.generating_args.to_dict()
+        generating_args.update(dict(
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            pad_token_id=self.tokenizer.pad_token_id
+        ))

         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
@@ -103,7 +104,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             self.model.eval()

             # Get inputs
-            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            queries, responses = self.get_inputs(batch, length_sampler, generating_args)
             self.tokenizer.padding_side = "right" # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)

@@ -152,32 +153,36 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     def get_inputs(
         self,
         batch: Dict[str, torch.Tensor],
-        length_sampler: Optional[Callable] = None,
-        **generation_kwargs
+        length_sampler: Callable,
+        generating_args: Dict[str, Any]
     ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
-        if length_sampler is not None:
-            generation_kwargs["max_new_tokens"] = length_sampler()
+        generating_args["max_new_tokens"] = length_sampler()
+        gen_kwargs = dict(
+            generation_config=GenerationConfig(**generating_args),
+            logits_processor=get_logits_processor(),
+            **batch
+        )
+
+        input_ids = batch["input_ids"]
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
-
-        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
-        # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
-        if unwrapped_model.pretrained_model.generation_config._from_model_config:
-            unwrapped_model.pretrained_model.generation_config._from_model_config = False
+        response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
+        query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()

         queries, responses = [], []
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
             response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()

             if len(response_index) == 0:
                 response_length = 1 # allow empty response
+            elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+                response_length = response_index[-1] + 2 # save the EOS token
             else:
                 response_length = response_index[-1] + 1

             queries.append(query[i, query_length:]) # remove padding from left
             responses.append(response[i, :response_length]) # remove padding from right
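The new elif matters when the pad token id equals the EOS token id (common for chat models): without it the response would be cut right before its EOS. A small sketch of the length computation with made-up ids:

import torch

pad_token_id = eos_token_id = 2
response = torch.tensor([5, 6, 7, 2, 2, 2])            # generated ids followed by padding

response_index = (response != pad_token_id).nonzero()
if len(response_index) == 0:
    response_length = 1                                 # allow empty response
elif pad_token_id == eos_token_id:
    response_length = response_index[-1].item() + 2     # keep the EOS right after the last real token
else:
    response_length = response_index[-1].item() + 1

print(response[:response_length])                       # tensor([5, 6, 7, 2])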
@@ -204,7 +209,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):

         rewards = []
         for i in range(values.size(0)):
-            end_index = batch["attention_mask"][i].nonzero()[-1]
+            end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type

         replace_model(unwrapped_model, target="default")
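The added comment documents the intent of the existing indexing: the value predicted at the last attended position (the EOS token) is used as the scalar reward. Sketch with invented numbers:

import torch

attention_mask = torch.tensor([1, 1, 1, 1, 0, 0])          # 4 real tokens, 2 pads
values = torch.tensor([0.1, 0.2, 0.3, 0.9, 0.0, 0.0])      # value-head output per position

end_index = attention_mask.nonzero()[-1]                   # position of the last real token (the EOS)
reward = values[end_index].float()
print(reward)                                              # tensor([0.9000])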
@@ -69,13 +69,17 @@ class PairwisePeftTrainer(PeftTrainer):
             assert div_index > 0
             chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
             rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
-            chosen_scores.append(chosen_trunc_rewards[-1]) # use the end score for inference
-            rejected_scores.append(rejected_trunc_rewards[-1])
+            if return_outputs: # use the score on the EOS token for inference
+                chosen_scores.append(chosen_rewards[i, chosen_length-1])
+                rejected_scores.append(rejected_rewards[i, rejected_length-1])
             loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

         loss = loss / batch_size
-        chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
-        return (loss, [loss, chosen_scores, rejected_scores]) if return_outputs else loss
+        if return_outputs:
+            chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+            return loss, [loss, chosen_scores, rejected_scores]
+
+        return loss

     def save_predictions(
         self,
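With this change the per-example scores are only collected and stacked when return_outputs is set, and each score is read at the last real token of the respective sequence. A self-contained sketch with invented tensors (chosen_length/rejected_length stand for the unpadded lengths used in the real code):

import torch
import torch.nn.functional as F

chosen_rewards = torch.tensor([[0.0, 0.1, 0.4, 0.8, 0.0]])      # reward per position, one pair
rejected_rewards = torch.tensor([[0.0, 0.2, 0.1, 0.0, 0.0]])
chosen_length, rejected_length = 4, 3                            # real (non-padded) token counts
div_index, end_index = 1, 4                                      # span where the two answers differ

chosen_trunc_rewards = chosen_rewards[0, div_index:end_index]
rejected_trunc_rewards = rejected_rewards[0, div_index:end_index]
loss = -F.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()

return_outputs = True
if return_outputs:                                               # score on the EOS token, as in the new code
    chosen_score = chosen_rewards[0, chosen_length - 1]          # tensor(0.8000)
    rejected_score = rejected_rewards[0, rejected_length - 1]    # tensor(0.1000)
print(loss, chosen_score, rejected_score)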
|
Loading…
x
Reference in New Issue
Block a user