Former-commit-id: aee634cd20e6dfdfbe2fbb47ae57f62b2da2bf9a
This commit is contained in:
hiyouga 2024-04-01 21:35:18 +08:00
parent b7468ea0a8
commit 34f1de0574
7 changed files with 37 additions and 16 deletions

View File

@ -264,8 +264,8 @@ huggingface-cli login
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| python | 3.8 | 3.10 | | python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 | | torch | 1.13.1 | 2.2.0 |
| transformers | 4.37.2 | 4.39.1 | | transformers | 4.37.2 | 4.39.2 |
| datasets | 2.14.3 | 2.17.1 | | datasets | 2.14.3 | 2.18.0 |
| accelerate | 0.27.2 | 0.28.0 | | accelerate | 0.27.2 | 0.28.0 |
| peft | 0.9.0 | 0.10.0 | | peft | 0.9.0 | 0.10.0 |
| trl | 0.8.1 | 0.8.1 | | trl | 0.8.1 | 0.8.1 |

View File

@ -264,8 +264,8 @@ huggingface-cli login
| ------------ | ------- | --------- | | ------------ | ------- | --------- |
| python | 3.8 | 3.10 | | python | 3.8 | 3.10 |
| torch | 1.13.1 | 2.2.0 | | torch | 1.13.1 | 2.2.0 |
| transformers | 4.37.2 | 4.39.1 | | transformers | 4.37.2 | 4.39.2 |
| datasets | 2.14.3 | 2.17.1 | | datasets | 2.14.3 | 2.18.0 |
| accelerate | 0.27.2 | 0.28.0 | | accelerate | 0.27.2 | 0.28.0 |
| peft | 0.9.0 | 0.10.0 | | peft | 0.9.0 | 0.10.0 |
| trl | 0.8.1 | 0.8.1 | | trl | 0.8.1 | 0.8.1 |

View File

@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]: elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
input_messages.append({"role": role_mapping[message.role], "content": message.content}) if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
name = message.tool_calls[0].function.name
arguments = message.tool_calls[0].function.arguments
content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
else:
input_messages.append({"role": role_mapping[message.role], "content": message.content})
tool_list = request.tools tool_list = request.tools
if isinstance(tool_list, list) and len(tool_list): if isinstance(tool_list, list) and len(tool_list):
try: try:
tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False) tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
except Exception: except Exception:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
else: else:

View File

@ -1,6 +1,6 @@
import time import time
from enum import Enum, unique from enum import Enum, unique
from typing import List, Optional from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Literal from typing_extensions import Literal
@ -39,6 +39,17 @@ class Function(BaseModel):
arguments: str arguments: str
class FunctionDefinition(BaseModel):
name: str
description: str
parameters: Dict[str, Any]
class FunctionAvailable(BaseModel):
type: Literal["function", "code_interpreter"] = "function"
function: Optional[FunctionDefinition] = None
class FunctionCall(BaseModel): class FunctionCall(BaseModel):
id: Literal["call_default"] = "call_default" id: Literal["call_default"] = "call_default"
type: Literal["function"] = "function" type: Literal["function"] = "function"
@ -47,7 +58,8 @@ class FunctionCall(BaseModel):
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: Role role: Role
content: str content: Optional[str] = None
tool_calls: Optional[List[FunctionCall]] = None
class ChatCompletionMessage(BaseModel): class ChatCompletionMessage(BaseModel):
@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
class ChatCompletionRequest(BaseModel): class ChatCompletionRequest(BaseModel):
model: str model: str
messages: List[ChatMessage] messages: List[ChatMessage]
tools: list = [] tools: Optional[List[FunctionAvailable]] = None
do_sample: bool = True do_sample: bool = True
temperature: Optional[float] = None temperature: Optional[float] = None
top_p: Optional[float] = None top_p: Optional[float] = None

View File

@ -193,6 +193,6 @@ def llama_flash_attn_forward(
def apply_llama_patch() -> None: def apply_llama_patch() -> None:
require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1") require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")
LlamaAttention.forward = llama_torch_attn_forward LlamaAttention.forward = llama_torch_attn_forward
LlamaFlashAttention2.forward = llama_flash_attn_forward LlamaFlashAttention2.forward = llama_flash_attn_forward

View File

@ -331,7 +331,7 @@ def patch_model(
): ):
gen_config.do_sample = True gen_config.do_sample = True
if model_args.resize_vocab: if is_trainable and model_args.resize_vocab:
_resize_embedding_layer(model, tokenizer) _resize_embedding_layer(model, tokenizer)
if is_trainable: if is_trainable:

View File

@ -15,7 +15,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
for grade, hour in zip(grades, hours): for grade, hour in zip(grades, hours):
total_score += grade_to_score[grade] * hour total_score += grade_to_score[grade] * hour
total_hour += hour total_hour += hour
return total_score / total_hour return round(total_score / total_hour, 2)
def main(): def main():
@ -45,16 +45,19 @@ def main():
messages = [] messages = []
messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."}) messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
result = client.chat.completions.create(messages=messages, model="test", tools=tools) result = client.chat.completions.create(messages=messages, model="test", tools=tools)
if result.choices[0].message.tool_calls is None:
raise ValueError("Cannot retrieve function call from the response.")
messages.append(result.choices[0].message)
tool_call = result.choices[0].message.tool_calls[0].function tool_call = result.choices[0].message.tool_calls[0].function
print(tool_call)
# Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
name, arguments = tool_call.name, json.loads(tool_call.arguments) name, arguments = tool_call.name, json.loads(tool_call.arguments)
messages.append(
{"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
)
tool_result = tool_map[name](**arguments) tool_result = tool_map[name](**arguments)
messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)}) messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
result = client.chat.completions.create(messages=messages, model="test", tools=tools) result = client.chat.completions.create(messages=messages, model="test", tools=tools)
print(result.choices[0].message.content) print(result.choices[0].message.content)
# Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665. # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.
if __name__ == "__main__": if __name__ == "__main__":