Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)
ppo support rm server
Former-commit-id: 747db4017291b0eb91946f57011bb31659056037
This commit is contained in:
parent 1cb390b9b2
commit 64eead3fb1
@@ -167,7 +167,8 @@ class ChatModel:
         scores = []
         for i in range(input_ids.size(0)):
-            length = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-            scores.append(values[i, length-1].nan_to_num().item())
+            end_indexes = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()
+            end_index = end_indexes[-1].item() if len(end_indexes) else 0
+            scores.append(values[i, end_index].nan_to_num().item())
 
         return scores
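A minimal standalone sketch (toy tensors and a hypothetical pad_token_id, not from the commit) of the end-index logic introduced above: take the position of the last non-pad token and fall back to 0 for a row that contains only padding, where the old `.nonzero()[-1] + 1` expression would raise on an empty result.

import torch

pad_token_id = 0
input_ids = torch.tensor([
    [11, 12, 13, 0, 0],  # last non-pad token at index 2
    [0, 0, 0, 0, 0],     # fully padded row: no non-pad tokens
])
values = torch.randn(2, 5)  # stand-in for the value-head output

scores = []
for i in range(input_ids.size(0)):
    end_indexes = (input_ids[i] != pad_token_id).nonzero()
    end_index = end_indexes[-1].item() if len(end_indexes) else 0
    scores.append(values[i, end_index].nan_to_num().item())

print(scores)  # one float per row, taken at indexes 2 and 0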
@@ -18,6 +18,7 @@ _flash_attn2_available = is_package_available("flash_attn") and get_package_vers
 _jieba_available = is_package_available("jieba")
 _matplotlib_available = is_package_available("matplotlib")
 _nltk_available = is_package_available("nltk")
+_requests_available = is_package_available("requests")
 _rouge_available = is_package_available("rouge_chinese")
 _starlette_available = is_package_available("sse_starlette")
 _uvicorn_available = is_package_available("uvicorn")
@@ -43,6 +44,10 @@ def is_nltk_available():
     return _nltk_available
 
 
+def is_requests_available():
+    return _requests_available
+
+
 def is_rouge_available():
     return _rouge_available
 
@@ -3,9 +3,9 @@ import sys
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl
+from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
 from transformers.utils import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
 from transformers.trainer_pt_utils import remove_dummy_checkpoint
@@ -16,7 +16,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
-from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model
+from llmtuner.train.ppo.utils import dump_layernorm, get_rewards_from_server, restore_layernorm, replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
@@ -200,7 +200,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         )
 
     @torch.no_grad()
-    def get_inputs(self, batch: BatchEncoding) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
         """
@@ -208,7 +208,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             layernorm_params = dump_layernorm(self.model)
 
         unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        response: torch.Tensor = unwrapped_model.generate(
+        generate_output: torch.Tensor = unwrapped_model.generate(
             generation_config=self.generation_config,
             logits_processor=get_logits_processor(),
             **batch
@@ -217,7 +217,8 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.upcast_layernorm:
             restore_layernorm(self.model, layernorm_params)
 
-        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+        query = batch["input_ids"].detach().cpu()
+        response = generate_output[:, batch["input_ids"].size(-1):].detach().cpu()
         queries, responses = [], []
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
@@ -242,17 +243,26 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
     ) -> List[torch.Tensor]:
         r"""
         Computes scores using given reward model.
+
+        Both inputs and outputs are put on CPU.
         """
-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "api":
+            token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            return get_rewards_from_server(self.reward_model, messages)
+
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="reward")
+            reward_model = self.model
+        else:
+            reward_model = self.reward_model
 
         batch = self.prepare_model_inputs(queries, responses)
 
         with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
-            reward_model = self.reward_model if self.reward_model is not None else self.model
             _, _, values = reward_model(**batch, output_hidden_states=True, return_dict=True)
 
-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
+        if getattr(unwrapped_model.config, "model_type", None) == "chatglm": # assume same architecture
             values = torch.transpose(values, 0, 1)
 
         rewards = []
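For the new "api" branch, the query and response token tensors are flattened into plain strings before being shipped to the reward server. A small sketch of that preprocessing step, assuming an arbitrary Hugging Face tokenizer (the "gpt2" checkpoint here is only a placeholder, not something the commit prescribes):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer

# One query/response pair, already tokenized, as the trainer holds them after generation.
queries = [torch.tensor(tokenizer.encode("What is 2 + 2?"))]
responses = [torch.tensor(tokenizer.encode(" It equals 4."))]

# Same two lines as the "api" branch above: concatenate each pair, then decode to text.
token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
messages = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

print(messages)  # one full query+response string per sample, ready for get_rewards_from_server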
@@ -261,7 +271,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             end_index = end_indexes[-1].item() if len(end_indexes) else 0
             rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
 
-        if self.reward_model is None:
+        if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
 
         return rewards
@@ -1,10 +1,24 @@
+import json
 import torch
-from typing import TYPE_CHECKING, Dict, Literal, Optional
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional
+
+from llmtuner.extras.packages import is_requests_available
 
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
 
+if is_requests_available():
+    import requests
+
+
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+    headers = {"Content-Type": "application/json"}
+    payload = {"model": "model", "messages": messages}
+    response = requests.post(server_url, json=payload, headers=headers)
+    rewards = json.loads(response.text)["scores"]
+    return torch.Tensor(rewards)
+
 
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
     if target == "reward": # save default head temporarily
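get_rewards_from_server above posts {"model": ..., "messages": [...]} and expects a JSON reply containing a "scores" list with one float per message. The commit itself does not ship a server, so the following is only a minimal sketch of a compatible endpoint; FastAPI, the /v1/score route, and the length-based scoring are all assumptions standing in for a real reward model service.

from typing import List

import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class ScoreRequest(BaseModel):
    model: str            # the client always sends "model"; unused in this sketch
    messages: List[str]   # decoded query+response strings


@app.post("/v1/score")  # hypothetical route; its full URL would be passed as finetuning_args.reward_model
def score(req: ScoreRequest):
    # Dummy scoring that rewards longer answers; a real server would run a reward model here.
    return {"scores": [float(len(m)) for m in req.messages]}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

With such a server running, get_rewards_from_server("http://localhost:8000/v1/score", ["hello"]) should return a one-element tensor.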
@@ -76,7 +76,9 @@ def create_reward_model(
     Creates reward model for PPO training.
     """
     if finetuning_args.reward_model_type == "api":
-        raise NotImplementedError
+        assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
+        logger.info("Use reward server {}".format(finetuning_args.reward_model))
+        return finetuning_args.reward_model
     elif finetuning_args.reward_model_type == "lora":
         model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
         for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
@@ -102,6 +104,6 @@ def create_reward_model(
         reward_model, _ = load_model_and_tokenizer(
             reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
         )
-        logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
+        logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model