diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py
index 670adbc0..96cde64c 100644
--- a/src/llmtuner/__init__.py
+++ b/src/llmtuner/__init__.py
@@ -7,5 +7,5 @@ from .train import export_model, run_exp
 from .webui import create_ui, create_web_demo
 
 
-__version__ = "0.5.0"
+__version__ = "0.5.1"
 __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"]
diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 2147a1db..428d15de 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -85,7 +85,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if not chat_model.can_generate:
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
-        if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
+        if len(request.messages) == 0 or request.messages[-1].role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
         messages = [dictify(message) for message in request.messages]
diff --git a/tests/test_toolcall.py b/tests/test_toolcall.py
new file mode 100644
index 00000000..a26af688
--- /dev/null
+++ b/tests/test_toolcall.py
@@ -0,0 +1,46 @@
+import os
+
+from openai import OpenAI
+
+
+os.environ["OPENAI_BASE_URL"] = "http://192.168.5.193:8000/v1"
+os.environ["OPENAI_API_KEY"] = "0"
+
+
+if __name__ == "__main__":
+    client = OpenAI()
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+    result = client.chat.completions.create(
+        messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
+        model="gpt-3.5-turbo",
+        tools=tools,
+    )
+    print(result.choices[0].message)
+    result = client.chat.completions.create(
+        messages=[
+            {"role": "user", "content": "What is the weather like in Boston?"},
+            {
+                "role": "function",
+                "content": """{"name": "get_current_weather", "arguments": {"location": "Boston, MA"}}""",
+            },
+            {"role": "tool", "content": '{"temperature": 22, "unit": "celsius", "description": "Sunny"}'},
+        ],
+        model="gpt-3.5-turbo",
+        tools=tools,
+    )
+    print(result.choices[0].message)
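
Reviewer note (not part of the patch): a minimal sketch of what the relaxed check in create_chat_completion now accepts. The Role enum below is a local stand-in with assumed values, since the real one is imported inside llmtuner, and plain dicts replace the pydantic request messages.

from enum import Enum, unique


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    TOOL = "tool"  # tool observations may now close the conversation


def last_message_is_valid(messages) -> bool:
    # Mirrors the updated condition: the request is rejected with HTTP 400
    # unless the final message comes from the user or is a tool observation.
    return len(messages) > 0 and messages[-1]["role"] in [Role.USER, Role.TOOL]


print(last_message_is_valid([{"role": Role.USER}]))                         # True
print(last_message_is_valid([{"role": Role.USER}, {"role": Role.TOOL}]))    # True, newly allowed in 0.5.1
print(last_message_is_valid([{"role": Role.ASSISTANT}]))                    # False, unchanged

This is what tests/test_toolcall.py exercises end to end: the second request ends with a "tool" message carrying the function result, which the 0.5.0 check would have rejected before the model could produce the final answer.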