Former-commit-id: f51d48ea081001ff2338a6a78231bfb13cfb0465
This commit is contained in:
hiyouga 2024-01-21 19:15:27 +08:00
parent fb2d563be5
commit daea73cf2b
3 changed files with 48 additions and 2 deletions

View File

@ -7,5 +7,5 @@ from .train import export_model, run_exp
from .webui import create_ui, create_web_demo from .webui import create_ui, create_web_demo
__version__ = "0.5.0" __version__ = "0.5.1"
__all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"] __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]

View File

@ -85,7 +85,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
if not chat_model.can_generate: if not chat_model.can_generate:
raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed") raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
if len(request.messages) == 0 or request.messages[-1].role != Role.USER: if len(request.messages) == 0 or request.messages[-1].role not in [Role.USER, Role.TOOL]:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
messages = [dictify(message) for message in request.messages] messages = [dictify(message) for message in request.messages]

46
tests/test_toolcall.py Normal file
View File

@ -0,0 +1,46 @@
import os
from openai import OpenAI
os.environ["OPENAI_BASE_URL"] = "http://192.168.5.193:8000/v1"
os.environ["OPENAI_API_KEY"] = "0"
if __name__ == "__main__":
client = OpenAI()
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}
},
"required": ["location"],
},
},
}
]
result = client.chat.completions.create(
messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
model="gpt-3.5-turbo",
tools=tools,
)
print(result.choices[0].message)
result = client.chat.completions.create(
messages=[
{"role": "user", "content": "What is the weather like in Boston?"},
{
"role": "function",
"content": """{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}""",
},
{"role": "tool", "content": '{"temperature": 22, "unit": "celsius", "description": "Sunny"}'},
],
model="gpt-3.5-turbo",
tools=tools,
)
print(result.choices[0].message)