implement rm server #1543

Former-commit-id: 7df4f3ab206fddb462f6ed865eaf04234fd72ed6
hiyouga 2023-12-03 20:52:54 +08:00
parent 2279b1948e
commit 1cb390b9b2
11 changed files with 104 additions and 24 deletions

View File

@@ -15,7 +15,9 @@ from llmtuner.api.protocol import (
     ChatCompletionStreamResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
-    ChatCompletionResponseUsage
+    ChatCompletionResponseUsage,
+    ScoreEvaluationRequest,
+    ScoreEvaluationResponse
 )
 from llmtuner.chat import ChatModel
 from llmtuner.extras.misc import torch_gc
@@ -68,6 +70,9 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
     async def create_chat_completion(request: ChatCompletionRequest):
+        if not chat_model.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
@@ -156,6 +161,17 @@
             yield to_json(chunk)
         yield "[DONE]"
 
+    @app.post("/v1/score/evaluation", response_model=ScoreEvaluationResponse, status_code=status.HTTP_200_OK)
+    async def create_score_evaluation(request: ScoreEvaluationRequest):
+        if chat_model.can_generate:
+            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
+
+        if len(request.messages) == 0:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+
+        scores = chat_model.get_scores(request.messages, max_length=request.max_length)
+        return ScoreEvaluationResponse(model=request.model, scores=scores)
+
     return app
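
For orientation, a minimal client call against the new scoring endpoint might look like the sketch below. The host, port, and model name are placeholders, and the server is assumed to have been started from a reward-model checkpoint (a non-generating stage), since the endpoint rejects requests when can_generate is true.

import requests  # assumed to be available in the client environment

# Hypothetical server address; adjust to wherever the llmtuner API is running.
response = requests.post(
    "http://localhost:8000/v1/score/evaluation",
    json={
        "model": "my-reward-model",                    # placeholder model name
        "messages": ["The assistant reply to score"],  # plain strings, one score each
        "max_length": 512,                             # optional truncation length
    },
)
print(response.json()["scores"])  # e.g. [0.42], one float per input string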

View File

@@ -81,3 +81,16 @@ class ChatCompletionStreamResponse(BaseModel):
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
+
+
+class ScoreEvaluationRequest(BaseModel):
+    model: str
+    messages: List[str]
+    max_length: Optional[int] = None
+
+
+class ScoreEvaluationResponse(BaseModel):
+    id: Optional[str] = "scoreeval-default"
+    object: Optional[str] = "score.evaluation"
+    model: str
+    scores: List[float]
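
As a rough illustration of the new protocol models (the model name and score values below are made up), a request/response pair could be constructed like this:

from llmtuner.api.protocol import ScoreEvaluationRequest, ScoreEvaluationResponse

# Build a scoring request for two candidate texts and the matching response.
request = ScoreEvaluationRequest(
    model="my-reward-model",                            # placeholder model name
    messages=["first candidate", "second candidate"],
    max_length=512,
)
response = ScoreEvaluationResponse(model=request.model, scores=[0.12, -0.87])
print(response.object, response.scores)  # score.evaluation [0.12, -0.87]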

View File

@@ -1,4 +1,5 @@
 import torch
+import tiktoken
 from dataclasses import dataclass
 from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
@@ -22,8 +23,11 @@ class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
-        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-        self.tokenizer.padding_side = "left"
+        self.can_generate = (finetuning_args.stage == "sft")
+        self.model, self.tokenizer = load_model_and_tokenizer(
+            model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
+        )
+        self.tokenizer.padding_side = "left" if self.can_generate else "right"
         self.model = dispatch_model(self.model)
         self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
         self.system_prompt = data_args.system_prompt
@@ -130,3 +134,40 @@
         thread.start()
 
         yield from streamer
+
+    @torch.inference_mode()
+    def get_scores(
+        self,
+        batch_input: List[str],
+        **input_kwargs
+    ) -> List[float]:
+        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=True)
+
+        max_length = input_kwargs.pop("max_length", None)
+        device = getattr(self.model.pretrained_model, "device", "cuda")
+        inputs = self.tokenizer(
+            batch_input,
+            padding=True,
+            truncation=True,
+            max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
+            pad_to_multiple_of=8,
+            return_tensors="pt",
+            **kwargs
+        ).to(device)
+
+        input_ids: torch.Tensor = inputs["input_ids"]
+        _, _, values = self.model(**inputs, output_hidden_states=True, return_dict=True)
+
+        if getattr(self.model.config, "model_type", None) == "chatglm":
+            values = torch.transpose(values, 0, 1)
+
+        scores = []
+        for i in range(input_ids.size(0)):
+            length = (input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            scores.append(values[i, length-1].nan_to_num().item())
+
+        return scores
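
A rough usage sketch of the new method outside the API server follows; the checkpoint path and template are placeholders, and the argument names assume the project's inference-argument parsing. Loading with stage="rm" attaches the value head and sets can_generate to False, so only get_scores() is usable.

from llmtuner.chat import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="path/to/reward_model",  # placeholder reward-model checkpoint
    template="default",                         # placeholder prompt template
    stage="rm",                                 # not "sft", so can_generate is False
))
scores = chat_model.get_scores(
    ["How do I bake bread?", "Tell me a joke."],
    max_length=512,
)
print(scores)  # one float per input string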

View File

@@ -118,9 +118,9 @@ class RLHFArguments:
         default=None,
         metadata={"help": "The number of bits to quantize the reward model."}
     )
-    reward_model_type: Optional[Literal["lora", "full"]] = field(
+    reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
         default="lora",
-        metadata={"help": "The checkpoint type of the reward model. The lora type only supports lora training."}
+        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
     )

View File

@@ -49,7 +49,7 @@ def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
+    add_valuehead: Optional[bool] = False
 ) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.
@@ -205,10 +205,9 @@ def load_model_and_tokenizer(
     # Initialize adapters
     model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
-    model = model.train() if is_trainable else model.eval()
 
     # Prepare model with valuehead for RLHF
-    if stage in ["rm", "ppo"]:
+    if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         setattr(model, "_keys_to_ignore_on_save", [name for name, _ in model.named_parameters() if "pretrained_model" in name])
         setattr(model, "tie_weights", MethodType(lambda _: None, model)) # use empty method
@@ -224,6 +223,9 @@ def load_model_and_tokenizer(
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
         model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
+        model.eval()
+    else:
+        model.train()
 
     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
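
The net effect of the signature change is that the value head is now attached on explicit request rather than being inferred from a training stage. A minimal sketch of a call with the new flag (the import path and the checkpoint path are assumptions):

from llmtuner.hparams import ModelArguments, FinetuningArguments
from llmtuner.model import load_model_and_tokenizer  # import path is an assumption

# Minimal argument objects; the checkpoint path is a placeholder.
model_args = ModelArguments(model_name_or_path="path/to/reward_model")
finetuning_args = FinetuningArguments(finetuning_type="lora")

model, tokenizer = load_model_and_tokenizer(
    model_args,
    finetuning_args,
    is_trainable=False,   # freezes parameters and puts the model in eval mode
    add_valuehead=True,   # wraps the model in AutoModelForCausalLMWithValueHead
)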

View File

@@ -25,11 +25,11 @@ def run_dpo(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
-        pad_to_multiple_of=4,
+        pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )
@@ -37,7 +37,7 @@ def run_dpo(
     if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
         ref_model = model
     else:
-        ref_model = create_ref_model(model_args, finetuning_args, stage="dpo")
+        ref_model = create_ref_model(model_args, finetuning_args)
 
     # Update arguments
     training_args_dict = training_args.to_dict()

View File

@@ -28,14 +28,14 @@ def run_ppo(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
     tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
     data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
 
     # Create reference model and reward model
-    ref_model = create_ref_model(model_args, finetuning_args, stage="ppo")
+    ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True)
     reward_model = create_reward_model(model, model_args, finetuning_args)
 
     # Create ppo config

View File

@@ -22,7 +22,7 @@ def run_pt(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
     data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

View File

@@ -25,9 +25,9 @@ def run_rm(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, add_valuehead=True)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
 
     # Update arguments
     training_args_dict = training_args.to_dict()

View File

@@ -26,7 +26,7 @@ def run_sft(
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
 
     if training_args.predict_with_generate:
@@ -34,7 +34,7 @@ def run_sft(
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
-        pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
+        pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
     )

View File

@@ -1,5 +1,5 @@
 import torch
-from typing import TYPE_CHECKING, Literal, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import ModelArguments, FinetuningArguments
@@ -35,7 +35,7 @@ def create_modelcard_and_push(
 def create_ref_model(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    stage: Literal["ppo", "dpo"]
+    add_valuehead: Optional[bool] = False
 ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
     r"""
     Creates reference model for PPO/DPO training. Evaluation mode is not supported.
@@ -51,13 +51,17 @@ def create_ref_model(
         ))
         ref_model_args = ModelArguments(**ref_model_args_dict)
         ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        ref_model, _ = load_model_and_tokenizer(ref_model_args, ref_finetuning_args, is_trainable=False, stage=stage)
+        ref_model, _ = load_model_and_tokenizer(
+            ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+        )
         logger.info("Created reference model from {}".format(finetuning_args.ref_model))
     else:
         if finetuning_args.finetuning_type == "lora":
             ref_model = None
         else:
-            ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage=stage)
+            ref_model, _ = load_model_and_tokenizer(
+                model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
+            )
             logger.info("Created reference model from the model itself.")
 
     return ref_model
@@ -71,7 +75,9 @@ def create_reward_model(
     r"""
     Creates reward model for PPO training.
     """
-    if finetuning_args.reward_model_type == "lora":
+    if finetuning_args.reward_model_type == "api":
+        raise NotImplementedError
+    elif finetuning_args.reward_model_type == "lora":
         model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
         for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
             if "default" in name:
@@ -93,7 +99,9 @@ def create_reward_model(
         ))
         reward_model_args = ModelArguments(**reward_model_args_dict)
         reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
-        reward_model, _ = load_model_and_tokenizer(reward_model_args, reward_finetuning_args, is_trainable=False, stage="ppo")
+        reward_model, _ = load_model_and_tokenizer(
+            reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
+        )
         logger.info("Load full weights of reward model from {}".format(finetuning_args.reward_model))
         logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
         return reward_model