mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
Former-commit-id: 9f4c2adc9a9ca8e458d3868805e077182e0d336a
parent 623a34b16f
commit caf4a61e21
@@ -49,7 +49,7 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
             pad_token_id=self.tokenizer.pad_token_id,
             logits_processor=get_logits_processor()
         ))
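Note: wrapping the concatenation in list(set(...)) deduplicates the stop-token ids passed to model.generate; if the tokenizer's eos token is also registered among its additional special tokens, the previous code sent the same id twice. A minimal illustration with made-up ids:

    eos_token_id = 2                                                  # hypothetical id, for illustration only
    additional_special_tokens_ids = [2, 32000]                        # hypothetical: eos id repeated here

    old = [eos_token_id] + additional_special_tokens_ids              # [2, 2, 32000] -- duplicate entry
    new = list(set([eos_token_id] + additional_special_tokens_ids))   # {2, 32000} as a list; order not guaranteed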
@@ -1,10 +1,6 @@
 import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
-from transformers import (
-    LogitsProcessor,
-    InfNanRemoveLogitsProcessor,
-    LogitsProcessorList
-)
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 
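Note: the LogitsProcessor base class is dropped from the import, leaving only the two names the module still uses. A minimal sketch of how a helper such as get_logits_processor (referenced elsewhere in this patch) is typically assembled from them; the body below is illustrative, not taken from this diff:

    from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

    def get_logits_processor() -> LogitsProcessorList:
        # Replace inf/NaN logits so that sampling cannot crash on degenerate scores.
        processors = LogitsProcessorList()
        processors.append(InfNanRemoveLogitsProcessor())
        return processors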
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 from transformers import TrainerState, TrainerControl
 
 from trl import PPOTrainer
-from trl.core import LengthSampler
+from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
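Note: logprobs_from_logits, newly imported from trl.core, returns the log-probability the model assigned to each realized token. An equivalent sketch of that computation (the helper name below is local to this example):

    import torch
    import torch.nn.functional as F

    def logprobs_from_logits_sketch(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # logits: (batch, seq_len, vocab); labels: (batch, seq_len)
        logp = F.log_softmax(logits, dim=2)
        # Gather the log-probability of the token that actually occurred at each position.
        return torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)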
@@ -35,6 +35,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["LogCallback"],
+        compute_dtype: torch.dtype,
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
@@ -42,6 +43,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.finetuning_args = finetuning_args
         self.generating_args = generating_args
         self.log_callback = callbacks[0]
+        self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
 
@@ -74,7 +76,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 
         # Keyword arguments for `model.generate`
         gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
+        gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
         gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
         gen_kwargs["logits_processor"] = get_logits_processor()
 
@@ -183,12 +185,74 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         replace_model(unwrapped_model, target="reward")
         batch = self.prepare_model_inputs(queries, responses)
         _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
-        if values.size(0) != batch["input_ids"].size(0):
+        if values.size(0) != batch["input_ids"].size(0): # adapt chatglm2
             values = torch.transpose(values, 0, 1)
         rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
         replace_model(unwrapped_model, target="default")
         return rewards
 
+    @PPODecorators.empty_cuda_cache()
+    def batched_forward_pass(
+        self,
+        model: "AutoModelForCausalLMWithValueHead",
+        queries: torch.Tensor,
+        responses: torch.Tensor,
+        model_inputs: dict,
+        return_logits: Optional[bool] = False
+    ):
+        r"""
+        Calculates model outputs in multiple batches.
+
+        Subclass and override to inject custom behavior.
+        """
+        bs = len(queries)
+        fbs = self.config.mini_batch_size
+        all_logprobs = []
+        all_logits = []
+        all_masks = []
+        all_values = []
+
+        for i in range(math.ceil(bs / fbs)):
+            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+            query_batch = queries[i * fbs : (i + 1) * fbs]
+            response_batch = responses[i * fbs : (i + 1) * fbs]
+            input_ids = input_kwargs["input_ids"]
+            attention_mask = input_kwargs["attention_mask"]
+
+            with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+                logits, _, values = model(**input_kwargs)
+
+            if values.size(0) != input_ids.size(0): # adapt chatglm2
+                values = torch.transpose(values, 0, 1)
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+            masks = torch.zeros_like(attention_mask)
+            masks[:, :-1] = attention_mask[:, 1:]
+
+            for j in range(len(query_batch)):
+                start = len(query_batch[j]) - 1
+                if attention_mask[j, 0] == 0: # offset left padding
+                    start += attention_mask[j, :].nonzero()[0]
+                end = start + len(response_batch[j])
+
+                masks[j, :start] = 0
+                masks[j, end:] = 0
+
+            if return_logits:
+                all_logits.append(logits)
+            else:
+                del logits
+            all_values.append(values)
+            all_logprobs.append(logprobs)
+            all_masks.append(masks)
+
+        return (
+            torch.cat(all_logprobs),
+            torch.cat(all_logits)[:, :-1] if return_logits else None,
+            torch.cat(all_values)[:, :-1],
+            torch.cat(all_masks)[:, :-1],
+        )
+
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.
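Note: the override reproduces trl's batched forward pass while adding an autocast context driven by self.compute_dtype and the chatglm2 value-head transpose. The mask bookkeeping keeps only response tokens: log-probabilities are shifted by one position (each logit scores the next token), and left padding shifts the response start. A toy illustration of that masking with made-up lengths:

    import torch

    # Hypothetical sequence: 2 left-padding tokens, a 3-token query, a 2-token response.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1]])
    query_len, response_len = 3, 2

    masks = torch.zeros_like(attention_mask)
    masks[:, :-1] = attention_mask[:, 1:]      # shift: position i scores the token at i + 1

    start = query_len - 1                      # last query position predicts the first response token
    if attention_mask[0, 0] == 0:              # offset left padding
        start += attention_mask[0, :].nonzero()[0]
    end = start + response_len

    masks[0, :start] = 0
    masks[0, end:] = 0
    # masks is now tensor([[0, 0, 0, 0, 1, 1, 0]]): only the response tokens are scored.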
@@ -60,6 +60,7 @@ def run_ppo(
         finetuning_args=finetuning_args,
         generating_args=generating_args,
         callbacks=callbacks,
+        compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
         ref_model=None,
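Note: run_ppo now threads model_args.compute_dtype into the trainer so batched_forward_pass can autocast under the right precision. A hedged sketch of how such a dtype is commonly resolved from bf16/fp16 flags (the function below is illustrative and not part of this repo):

    import torch

    def resolve_compute_dtype(bf16: bool, fp16: bool) -> torch.dtype:
        # Prefer bfloat16 when requested, then float16, otherwise full precision.
        if bf16:
            return torch.bfloat16
        if fp16:
            return torch.float16
        return torch.float32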
|
@ -42,7 +42,7 @@ class PairwisePeftTrainer(PeftTrainer):
|
|||||||
"""
|
"""
|
||||||
batch_size = inputs["input_ids"].size(0) // 2
|
batch_size = inputs["input_ids"].size(0) // 2
|
||||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||||
if values.size(0) != inputs["input_ids"].size(0):
|
if values.size(0) != inputs["input_ids"].size(0): # adapt chatglm2
|
||||||
values = torch.transpose(values, 0, 1)
|
values = torch.transpose(values, 0, 1)
|
||||||
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
||||||
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
||||||
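Note: the loss in this hunk, -log(sigmoid(r_accept - r_reject)), is the standard pairwise ranking (Bradley-Terry) objective and is numerically identical to softplus(r_reject - r_accept). A quick sanity check with illustrative reward values:

    import torch
    import torch.nn.functional as F

    r_accept = torch.tensor([1.5, 0.2])        # made-up chosen-response rewards
    r_reject = torch.tensor([0.5, 0.9])        # made-up rejected-response rewards

    loss_a = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
    loss_b = F.softplus(r_reject - r_accept).mean()
    assert torch.allclose(loss_a, loss_b)      # both forms agree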
|
@ -52,7 +52,7 @@ def run_sft(
|
|||||||
|
|
||||||
# Keyword arguments for `model.generate`
|
# Keyword arguments for `model.generate`
|
||||||
gen_kwargs = generating_args.to_dict()
|
gen_kwargs = generating_args.to_dict()
|
||||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
|
||||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||||
|
|
||||||
|