Former-commit-id: c7e51ff187658eb472c2b234f75d8934c6f7c782
hiyouga 2024-09-11 17:36:42 +08:00
parent 38505ae9e1
commit 009500bc6d
4 changed files with 12 additions and 22 deletions

View File

@@ -229,8 +229,9 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
+    score_id = "scoreval-{}".format(uuid.uuid4().hex)
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
-    return ScoreEvaluationResponse(model=request.model, scores=scores)
+    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
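The added score_id appears to mirror the id scheme used for the chat completion endpoints in this API: a fixed prefix plus uuid.uuid4().hex. A minimal sketch of what the new line produces, assuming only the standard library:

import uuid

score_id = "scoreval-{}".format(uuid.uuid4().hex)
print(score_id)  # e.g. scoreval-<32 hex chars>, unique per request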

View File

@@ -246,29 +246,18 @@ class HuggingfaceEngine(BaseEngine):
         batch_input: List[str],
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List[float]:
-        max_length = input_kwargs.pop("max_length", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs = tokenizer(
+        inputs: Dict[str, "torch.Tensor"] = tokenizer(
             batch_input,
             padding=True,
             truncation=True,
             max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            add_special_tokens=True,
+            add_special_tokens=False,
         ).to(device)
-
-        input_ids: torch.Tensor = inputs["input_ids"]
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-
-        if getattr(model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
-        scores = []
-        for i in range(input_ids.size(0)):
-            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            scores.append(values[i, end_index].nan_to_num().item())
-
+        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores
 
     @override
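This hunk replaces the per-sample Python loop with one vectorized lookup: for each sequence, the index of its last real token is the number of ones in the attention mask minus one, and gather picks the value at that position. Unlike the old code, which compared input_ids against tokenizer.pad_token_id, the mask-based index stays correct even when the pad token id also occurs as a real token (e.g. when pad is reused as eos). A standalone sketch of the indexing trick, with hypothetical shapes and right-padded inputs:

import torch

# Stand-in for the value head output: one value per token, shape (batch, seq_len).
values = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                       [1.0, 2.0, 3.0, 0.0, 0.0]])
# attention_mask marks real tokens (1) versus right padding (0).
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Last real position = token count - 1; keepdim=True keeps a (batch, 1) shape for gather.
last_index = attention_mask.sum(dim=-1, keepdim=True) - 1  # tensor([[4], [2]])
scores = values.gather(dim=-1, index=last_index)           # tensor([[0.5000], [3.0000]])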

View File

@@ -31,7 +31,7 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
 
 
-def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
     r"""
     Gets reward scores from the API server.
     """
@@ -66,7 +66,7 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
 
 
-def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
     r"""
     Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """
@@ -79,7 +79,7 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
     return layer_norm_params
 
 
-def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
     r"""
     Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """

View File

@@ -392,7 +392,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         """
         if self.finetuning_args.reward_model_type == "api":
             token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
-            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
             return get_rewards_from_server(self.reward_model, messages)
 
         batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)
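Switching to skip_special_tokens=False keeps special tokens (bos/eos, chat-template markers) in the decoded strings, presumably so the remote reward server scores the same rendering the policy model actually produced. A small sketch of the difference, using bert-base-uncased purely as a stand-in for any tokenizer that inserts special tokens:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer("hello")["input_ids"]  # [101, 7592, 102], i.e. [CLS] hello [SEP]

print(tokenizer.batch_decode([ids], skip_special_tokens=True))   # ['hello']
print(tokenizer.batch_decode([ids], skip_special_tokens=False))  # ['[CLS] hello [SEP]']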
@@ -405,7 +405,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model
 
         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            _, _, values = reward_model(**batch, return_dict=True, use_cache=False)
+            values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
 
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")
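Both the old unpacking and the new [-1] index rely on the same output shape: trl's AutoModelForCausalLMWithValueHead returns a (lm_logits, loss, value) tuple, so the value tensor is the last element. The indexing form also matches the change made in the HuggingFace engine above. A sketch with a stand-in tuple:

import torch

# Stand-in for the model's (lm_logits, loss, value) return tuple; shapes hypothetical.
output = (torch.zeros(2, 5, 32000), None, torch.zeros(2, 5))

_, _, values_old = output  # old style: positional unpacking of all three elements
values_new = output[-1]    # new style: take the value tensor directly
assert values_new is output[2] and values_old is output[2]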