mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
parent 38505ae9e1
commit 009500bc6d
@@ -229,8 +229,9 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
+    score_id = "scoreval-{}".format(uuid.uuid4().hex)
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
 
     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
-    return ScoreEvaluationResponse(model=request.model, scores=scores)
+    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)

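The new identifier apparently follows the same "<prefix>-<uuid4 hex>" pattern as the chat completion ids generated elsewhere in this module. A minimal sketch of the resulting payload, assuming ScoreEvaluationResponse serializes its fields as-is (the helper below and any fields beyond id/model/scores are illustrative, not the project's code):

import uuid

# hypothetical stand-in for the pydantic response model
def make_score_response(model_name: str, scores: list) -> dict:
    score_id = "scoreval-{}".format(uuid.uuid4().hex)  # same id scheme as the patch
    return {"id": score_id, "model": model_name, "scores": scores}

print(make_score_response("my-reward-model", [0.42, -1.3]))
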
@@ -246,29 +246,18 @@ class HuggingfaceEngine(BaseEngine):
         batch_input: List[str],
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List[float]:
-        max_length = input_kwargs.pop("max_length", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs = tokenizer(
+        inputs: Dict[str, "torch.Tensor"] = tokenizer(
             batch_input,
             padding=True,
             truncation=True,
             max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            add_special_tokens=True,
+            add_special_tokens=False,
         ).to(device)
-
-        input_ids: torch.Tensor = inputs["input_ids"]
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-
-        if getattr(model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
-        scores = []
-        for i in range(input_ids.size(0)):
-            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            scores.append(values[i, end_index].nan_to_num().item())
-
+        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores
 
     @override

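The rewritten scoring path above replaces the per-sample loop with a single gather over the last non-padding position of each sequence. A minimal, self-contained sketch of that equivalence (the toy tensors and the right-padding assumption are mine; in the patched code values comes from the model's value head):

import torch

# per-token outputs of a value head, shape (batch, seq_len); toy numbers
values = torch.tensor([[0.5, 1.0, 1.5, 2.0],
                       [2.5, 3.0, 0.0, 0.0]])
# right padding: the 1s in the attention mask form a prefix
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

# index of the last real token per row: 3 for the first sample, 1 for the second
last_token = attention_mask.sum(dim=-1, keepdim=True) - 1
scores = values.gather(dim=-1, index=last_token)

# the removed loop picked the same positions one sample at a time,
# scanning for the last non-pad token instead of summing the mask
loop_scores = [values[i, int(attention_mask[i].sum()) - 1].item() for i in range(values.size(0))]

print(scores.squeeze(-1).tolist())  # [2.0, 3.0]
print(loop_scores)                  # [2.0, 3.0]

With left padding the mask-sum index would point at the wrong position, so the vectorized form assumes right padding.
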
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
 
 
-def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
+def get_rewards_from_server(server_url: str, messages: List[str]) -> List["torch.Tensor"]:
     r"""
     Gets reward scores from the API server.
     """

@@ -66,7 +66,7 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         v_head_layer.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone().to(device)
 
 
-def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
+def dump_layernorm(model: "PreTrainedModel") -> Dict[str, "torch.Tensor"]:
     r"""
     Dumps the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """

@@ -79,7 +79,7 @@ def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
     return layer_norm_params
 
 
-def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
+def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
     r"""
     Restores the layernorm parameters in the model. The model is already unwrapped (and gathered).
     """

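The three hunks above only quote torch.Tensor inside type annotations. A quoted annotation is a forward reference: it stays a string at runtime, so it also works when the referenced module is imported only under if TYPE_CHECKING: (whether that is the case for torch in this file, or the change is purely for stylistic consistency, is not visible in the diff). A minimal sketch of the pattern:

from typing import TYPE_CHECKING, Dict

if TYPE_CHECKING:
    import torch  # visible to the type checker only; never imported at runtime here

def dump_params() -> Dict[str, "torch.Tensor"]:
    # the quoted "torch.Tensor" is never evaluated at runtime,
    # so defining this function does not require torch to be importable
    return {}

print(dump_params())  # {}
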
@@ -392,7 +392,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         """
         if self.finetuning_args.reward_model_type == "api":
             token_ids = [torch.cat((q, r), dim=-1).tolist() for q, r in zip(queries, responses)]
-            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=True)
+            messages = self.tokenizer.batch_decode(token_ids, skip_special_tokens=False)
             return get_rewards_from_server(self.reward_model, messages)
 
         batch: Dict[str, "torch.Tensor"] = self.prepare_model_inputs(queries, responses)

@@ -405,7 +405,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             reward_model = self.reward_model
 
         with unwrap_model_for_generation(reward_model, self.accelerator), self.amp_context:  # support bf16
-            _, _, values = reward_model(**batch, return_dict=True, use_cache=False)
+            values: "torch.Tensor" = reward_model(**batch, return_dict=True, use_cache=False)[-1]
 
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")

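In the last hunk, the value tensor is taken by indexing the model output with [-1] instead of tuple unpacking. trl's AutoModelForCausalLMWithValueHead returns a (lm_logits, loss, value) tuple, so both spellings select the same per-token values. The toy stand-in below only illustrates that; the class, shapes, and numbers are mine, not the trainer's code:

import torch

class ToyValueHeadModel:
    def __call__(self, input_ids, attention_mask, return_dict=True, use_cache=False):
        batch, seq_len = input_ids.shape
        lm_logits = torch.zeros(batch, seq_len, 32)  # vocabulary logits (unused here)
        loss = None                                  # no labels, so no loss
        values = torch.randn(batch, seq_len)         # one scalar value per token
        return lm_logits, loss, values

model = ToyValueHeadModel()
batch = {"input_ids": torch.ones(2, 5, dtype=torch.long),
         "attention_mask": torch.ones(2, 5, dtype=torch.long)}

values_new = model(**batch, return_dict=True, use_cache=False)[-1]    # patched style
_, _, values_old = model(**batch, return_dict=True, use_cache=False)  # old unpacking
print(values_new.shape, values_old.shape)  # torch.Size([2, 5]) torch.Size([2, 5])
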