Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
parent: b7468ea0a8
commit: 34f1de0574
@@ -264,8 +264,8 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.39.1    |
-| datasets     | 2.14.3  | 2.17.1    |
+| transformers | 4.37.2  | 4.39.2    |
+| datasets     | 2.14.3  | 2.18.0    |
 | accelerate   | 0.27.2  | 0.28.0    |
 | peft         | 0.9.0   | 0.10.0    |
 | trl          | 0.8.1   | 0.8.1     |
@@ -264,8 +264,8 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.39.1    |
-| datasets     | 2.14.3  | 2.17.1    |
+| transformers | 4.37.2  | 4.39.2    |
+| datasets     | 2.14.3  | 2.18.0    |
 | accelerate   | 0.27.2  | 0.28.0    |
 | peft         | 0.9.0   | 0.10.0    |
 | trl          | 0.8.1   | 0.8.1     |
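The bumps above only touch the recommended column. As a quick way to compare an installed environment against it, here is a minimal standard-library sketch; the pins are copied from the table and nothing in this snippet is part of the repository:

from importlib.metadata import PackageNotFoundError, version

# Recommended versions, copied from the table above.
RECOMMENDED = {
    "torch": "2.2.0",
    "transformers": "4.39.2",
    "datasets": "2.18.0",
    "accelerate": "0.28.0",
    "peft": "0.10.0",
    "trl": "0.8.1",
}

for package, wanted in RECOMMENDED.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        installed = "missing"
    marker = "" if installed == wanted else "  <- differs from recommended"
    print(f"{package:<12} {installed}{marker}")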
@@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
 
-            input_messages.append({"role": role_mapping[message.role], "content": message.content})
+            if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
+                name = message.tool_calls[0].function.name
+                arguments = message.tool_calls[0].function.arguments
+                content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+                input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
+            else:
+                input_messages.append({"role": role_mapping[message.role], "content": message.content})
 
         tool_list = request.tools
         if isinstance(tool_list, list) and len(tool_list):
             try:
-                tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
+                tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
             except Exception:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
         else:
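The new branch collapses an assistant message carrying tool_calls into a single FUNCTION-role message before the prompt template sees it. Below is a self-contained sketch of that transformation on plain dicts; the pydantic models and role_mapping of the real handler are replaced by stand-ins:

import json

# Stand-in for the handler's role_mapping; the real one maps protocol
# roles to the template's internal role names.
ROLE_MAPPING = {"user": "user", "assistant": "assistant", "function": "function"}

def to_input_message(message: dict) -> dict:
    """Mirror of the new branch: an assistant message with tool_calls is
    re-encoded as a function-role message holding the serialized call."""
    tool_calls = message.get("tool_calls") or []
    if message["role"] == "assistant" and len(tool_calls):
        name = tool_calls[0]["function"]["name"]
        arguments = tool_calls[0]["function"]["arguments"]
        content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
        return {"role": ROLE_MAPPING["function"], "content": content}
    return {"role": ROLE_MAPPING[message["role"]], "content": message["content"]}

assistant_turn = {
    "role": "assistant",
    "content": None,
    "tool_calls": [{"function": {"name": "calculate_gpa", "arguments": '{"grades": ["A"]}'}}],
}
print(to_input_message(assistant_turn)["role"])  # function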
@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -39,6 +39,17 @@ class Function(BaseModel):
     arguments: str
 
 
+class FunctionDefinition(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+
+class FunctionAvailable(BaseModel):
+    type: Literal["function", "code_interpreter"] = "function"
+    function: Optional[FunctionDefinition] = None
+
+
 class FunctionCall(BaseModel):
     id: Literal["call_default"] = "call_default"
     type: Literal["function"] = "function"
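With these models, an OpenAI-style tools entry validates straight into structured objects. A minimal round-trip sketch follows; the class bodies are restated from the hunk above so the snippet runs on its own:

from typing import Any, Dict, Optional

from pydantic import BaseModel
from typing_extensions import Literal

class FunctionDefinition(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]

class FunctionAvailable(BaseModel):
    type: Literal["function", "code_interpreter"] = "function"
    function: Optional[FunctionDefinition] = None

# An OpenAI-style tools entry; the nested dict is coerced into
# FunctionDefinition by pydantic.
payload = {
    "type": "function",
    "function": {
        "name": "calculate_gpa",
        "description": "Compute a GPA from grades and credit hours.",
        "parameters": {"type": "object", "properties": {}},
    },
}
tool = FunctionAvailable(**payload)
print(tool.function.name)  # calculate_gpa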
@@ -47,7 +58,8 @@ class FunctionCall(BaseModel):
 
 class ChatMessage(BaseModel):
     role: Role
-    content: str
+    content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None
 
 
 class ChatCompletionMessage(BaseModel):
@@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    tools: list = []
+    tools: Optional[List[FunctionAvailable]] = None
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
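This typing change is what forces the app.py edit above: request.tools items are now FunctionAvailable models rather than raw dicts, so tool["function"] no longer works and each entry has to be converted back to a plain dict before json.dumps. A sketch of the assumed behavior of the dictify helper, whose real implementation lives elsewhere in the repository:

from pydantic import BaseModel

def dictify(data: "BaseModel") -> dict:
    # Assumed behavior: turn a pydantic model into a plain dict so it can
    # be fed to json.dumps; pydantic v1 spells this .dict(), v2 .model_dump().
    return data.dict()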
@@ -193,6 +193,6 @@ def llama_flash_attn_forward(
 
 
 def apply_llama_patch() -> None:
-    require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
+    require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")
     LlamaAttention.forward = llama_torch_attn_forward
     LlamaFlashAttention2.forward = llama_flash_attn_forward
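apply_llama_patch monkey-patches LlamaAttention internals, so the pin has to move in lockstep with the supported transformers release. require_version raises when the installed version does not satisfy the requirement; a sketch of the failure mode:

from transformers.utils.versions import require_version

# Raises ImportError (with the hint appended) if the installed
# transformers version is not exactly 4.39.2.
require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")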
@@ -331,7 +331,7 @@ def patch_model(
     ):
         gen_config.do_sample = True
 
-    if model_args.resize_vocab:
+    if is_trainable and model_args.resize_vocab:
         _resize_embedding_layer(model, tokenizer)
 
     if is_trainable:
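Adding the is_trainable guard means a pure-inference load no longer resizes the embedding matrix to match the tokenizer's vocabulary. A sketch of the guarded behavior using the public transformers API; _resize_embedding_layer itself is defined elsewhere in the repository and the body here is an assumption:

def maybe_resize_vocab(model, tokenizer, resize_vocab: bool, is_trainable: bool) -> None:
    # Patched behavior: embeddings are only resized when the model will be
    # trained; inference-only sessions keep the checkpoint's original shape.
    if is_trainable and resize_vocab:
        model.resize_token_embeddings(len(tokenizer))  # standard transformers call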
@@ -15,7 +15,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
     for grade, hour in zip(grades, hours):
         total_score += grade_to_score[grade] * hour
         total_hour += hour
-    return total_score / total_hour
+    return round(total_score / total_hour, 2)
 
 
 def main():
@@ -45,16 +45,19 @@ def main():
     messages = []
     messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
+    if result.choices[0].message.tool_calls is None:
+        raise ValueError("Cannot retrieve function call from the response.")
+
+    messages.append(result.choices[0].message)
     tool_call = result.choices[0].message.tool_calls[0].function
+    print(tool_call)
+    # Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
     name, arguments = tool_call.name, json.loads(tool_call.arguments)
-    messages.append(
-        {"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
-    )
     tool_result = tool_map[name](**arguments)
     messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
     print(result.choices[0].message.content)
-    # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
+    # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.
 
 
 if __name__ == "__main__":
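The added messages.append(result.choices[0].message) is the client-side half of the app.py change above: the assistant turn is sent back with its tool_calls intact, and the server flattens it into a FUNCTION-role message. The conversation the second completion call sees then looks roughly like this; the values are illustrative, taken from the script's own comments:

messages = [
    {"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."},
    # result.choices[0].message: the assistant turn, content None, tool_calls set;
    # the server re-encodes it as a function-role message (see app.py above)
    {"role": "tool", "content": '{"gpa": 3.42}'},
]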