mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 16:12:48 +08:00

commit 1fbf83acc2
parent 865f48f1c3

    finish agent

    Former-commit-id: 3e982cc7146b408b359bb38ed4aa840183b23d72

src/llmtuner/api/app.py

@@ -7,18 +7,20 @@ from typing import Any, Dict, Sequence
 from pydantic import BaseModel
 
 from ..chat import ChatModel
+from ..data import Role as DataRole
 from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_availble, is_starlette_available, is_uvicorn_available
 from .protocol import (
+    ChatCompletionMessage,
     ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseChoice,
     ChatCompletionResponseStreamChoice,
     ChatCompletionResponseUsage,
     ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
     Finish,
+    Function,
+    FunctionCall,
     ModelCard,
     ModelList,
     Role,
@@ -84,7 +86,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")
 
         if len(request.messages) == 0 or request.messages[-1].role != Role.USER:
-            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
 
         messages = [dictify(message) for message in request.messages]
         if len(messages) and messages[0]["role"] == Role.SYSTEM:
@@ -96,16 +98,21 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
 
         for i in range(len(messages)):
-            if messages[i]["role"] == Role.USER:
-                if i % 2 == 1:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
-            elif messages[i]["role"] == Role.ASSISTANT:
-                if i % 2 == 0:
-                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
-            else:
-                raise NotImplementedError
+            if i % 2 == 0 and messages[i]["role"] not in [Role.USER, Role.TOOL]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif i % 2 == 1 and messages[i]["role"] not in [Role.ASSISTANT, Role.FUNCTION]:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
+            elif messages[i]["role"] == Role.TOOL:
+                messages[i]["role"] = DataRole.OBSERVATION
 
-        tools = ""  # TODO: add tools
+        tool_list = request.tools
+        if len(tool_list):
+            try:
+                tools = json.dumps([tool_list[0]["function"]], ensure_ascii=False)
+            except Exception:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
+        else:
+            tools = ""
 
         async with semaphore:
             loop = asyncio.get_running_loop()
@@ -130,12 +137,24 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         prompt_length, response_length = 0, 0
         choices = []
         for i, response in enumerate(responses):
-            choices.append(
-                ChatCompletionResponseChoice(
-                    index=i,
-                    message=ChatMessage(role=Role.ASSISTANT, content=response.response_text),
-                    finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH,
+            if tools:
+                result = chat_model.template.format_tools.extract(response.response_text)
+            else:
+                result = response.response_text
+
+            if isinstance(result, tuple):
+                name, arguments = result
+                function = Function(name=name, arguments=arguments)
+                response_message = ChatCompletionMessage(
+                    role=Role.ASSISTANT, tool_calls=[FunctionCall(function=function)]
                 )
+                finish_reason = Finish.TOOL
+            else:
+                response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
+                finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH
+
+            choices.append(
+                ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason)
             )
             prompt_length = response.prompt_length
             response_length += response.response_length
@@ -152,7 +171,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest
     ):
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0, delta=DeltaMessage(role=Role.ASSISTANT, content=""), finish_reason=None
+            index=0, delta=ChatCompletionMessage(role=Role.ASSISTANT, content=""), finish_reason=None
         )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield jsonify(chunk)
@@ -170,12 +189,14 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
                 continue
 
             choice_data = ChatCompletionResponseStreamChoice(
-                index=0, delta=DeltaMessage(content=new_text), finish_reason=None
+                index=0, delta=ChatCompletionMessage(content=new_text), finish_reason=None
            )
             chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield jsonify(chunk)
 
-        choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason=Finish.STOP)
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=0, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
+        )
         chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield jsonify(chunk)
         yield "[DONE]"
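
Taken together, the app.py changes let the /v1/chat/completions route accept an OpenAI-style tools array and return tool_calls in the choice message, with finish_reason set to "tool_calls". A minimal client-side sketch of the new request/response shape; the endpoint URL, port, model name, and function schema are illustrative assumptions, not part of this commit:

import json

import requests  # assumed client dependency, purely for illustration

# Only the first tool's "function" object is serialized into the template
# (see json.dumps([tool_list[0]["function"]], ...) above).
payload = {
    "model": "default",  # illustrative model name
    "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "parameters": {"location": {"type": "string"}},
            },
        }
    ],
}
choice = requests.post("http://localhost:8000/v1/chat/completions", json=payload).json()["choices"][0]
if choice["finish_reason"] == "tool_calls":
    call = choice["message"]["tool_calls"][0]["function"]
    print(call["name"], json.loads(call["arguments"]))
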
src/llmtuner/api/protocol.py

@@ -11,12 +11,15 @@ class Role(str, Enum):
     USER = "user"
     ASSISTANT = "assistant"
     SYSTEM = "system"
+    FUNCTION = "function"
+    TOOL = "tool"
 
 
 @unique
 class Finish(str, Enum):
     STOP = "stop"
     LENGTH = "length"
+    TOOL = "tool_calls"
 
 
 class ModelCard(BaseModel):
@@ -31,19 +34,32 @@ class ModelList(BaseModel):
     data: List[ModelCard] = []
 
 
+class Function(BaseModel):
+    name: str
+    arguments: str
+
+
+class FunctionCall(BaseModel):
+    id: Literal["call_default"] = "call_default"
+    type: Literal["function"] = "function"
+    function: Function
+
+
 class ChatMessage(BaseModel):
     role: Role
     content: str
 
 
-class DeltaMessage(BaseModel):
+class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None
 
 
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
+    tools: Optional[list] = []
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
@@ -54,13 +70,13 @@ class ChatCompletionRequest(BaseModel):
 
 class ChatCompletionResponseChoice(BaseModel):
     index: int
-    message: ChatMessage
+    message: ChatCompletionMessage
     finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
-    delta: DeltaMessage
+    delta: ChatCompletionMessage
     finish_reason: Optional[Finish] = None
 
 
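
The new protocol models nest Function inside FunctionCall inside ChatCompletionMessage. A quick sketch of how a tool-call message is built and serialized; the import path is inferred from the repo layout, and .dict() assumes the pydantic v1-style API:

from llmtuner.api.protocol import ChatCompletionMessage, Function, FunctionCall, Role

# An assistant turn that invokes a tool instead of returning text; id and
# type fall back to their Literal defaults ("call_default" / "function").
message = ChatCompletionMessage(
    role=Role.ASSISTANT,
    tool_calls=[FunctionCall(function=Function(name="get_weather", arguments='{"location": "Berlin"}'))],
)
# role serializes to "assistant"; content is omitted, tool_calls carries the nested call.
print(message.json(exclude_none=True))
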
src/llmtuner/chat/chat_model.py

@@ -37,9 +37,9 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
-        new_messages = messages + [{"role": "assistant", "content": ""}]
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
         prompt, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, messages=new_messages, system=system, tools=tools
+            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )
         prompt_length = len(prompt)
         input_ids = torch.tensor([prompt], device=self.model.device)
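
The rename from new_messages to paired_messages reflects what encode_oneturn consumes: a list of complete user/assistant pairs, so an empty assistant stub is appended to close the final pair before the prompt is encoded. Roughly:

# History arriving from the API always ends on a user turn...
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2 + 2?"},
]
# ...so an empty assistant stub closes the last pair; the template then
# encodes everything up to the generation prompt for that stub.
paired_messages = messages + [{"role": "assistant", "content": ""}]
assert len(paired_messages) % 2 == 0
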
src/llmtuner/data/formatter.py

@@ -74,6 +74,9 @@ class Formatter(ABC):
     def apply(self, **kwargs) -> SLOTS:
         ...
 
+    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
+        raise NotImplementedError
+
 
 @dataclass
 class EmptyFormatter(Formatter):
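
extract is the parsing counterpart of apply: a tool-aware formatter renders tool definitions into the prompt and parses the model's raw output back into a (name, arguments) tuple, or returns the string unchanged when no call is present. A hypothetical subclass just to show the contract; this commit only defines the abstract method, and the output pattern below is an assumption:

import re
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass
class ReActToolFormatter(EmptyFormatter):  # hypothetical, for illustration only
    def extract(self, content: str) -> Union[str, Tuple[str, str]]:
        # Assumes ReAct-style output: "Action: <name>\nAction Input: <json>"
        match = re.search(r"Action:\s*(.+?)\s*Action Input:\s*(.+)", content, re.DOTALL)
        if match is None:
            return content  # plain text, no tool call detected
        return match.group(1).strip(), match.group(2).strip()
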
src/llmtuner/webui/chatter.py

@@ -1,9 +1,11 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+import json
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
 
 import gradio as gr
 from gradio.components import Component  # cannot use TYPE_CHECKING here
 
 from ..chat import ChatModel
+from ..data import Role
 from ..extras.misc import torch_gc
 from ..hparams import GeneratingArguments
 from .common import get_save_dir
@@ -105,22 +107,37 @@ class WebChatModel(ChatModel):
         self,
         chatbot: List[Tuple[str, str]],
         query: str,
-        history: List[Tuple[str, str]],
+        messages: Sequence[Tuple[str, str]],
         system: str,
         tools: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float,
-    ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
+    ) -> Generator[Tuple[Sequence[Tuple[str, str]], Sequence[Tuple[str, str]]], None, None]:
         chatbot.append([query, ""])
+        query_messages = messages + [{"role": Role.USER, "content": query}]
         response = ""
         for new_text in self.stream_chat(
-            query, history, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query_messages, system, tools, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
-            new_history = history + [(query, response)]
-            chatbot[-1] = [query, self.postprocess(response)]
-            yield chatbot, new_history
+            if tools:
+                result = self.template.format_tools.extract(response)
+            else:
+                result = response
+
+            if isinstance(result, tuple):
+                name, arguments = result
+                arguments = json.loads(arguments)
+                tool_call = json.dumps({"name": name, "arguments": arguments}, ensure_ascii=False)
+                output_messages = query_messages + [{"role": Role.FUNCTION, "content": tool_call}]
+                bot_text = "```json\n" + tool_call + "\n```"
+            else:
+                output_messages = query_messages + [{"role": Role.ASSISTANT, "content": result}]
+                bot_text = result
+
+            chatbot[-1] = [query, self.postprocess(bot_text)]
+            yield chatbot, output_messages
 
     def postprocess(self, response: str) -> str:
         blocks = response.split("```")
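
With this change the Web UI keeps role-tagged message dicts in its Gradio state instead of (query, response) tuples, and a recognized tool call is stored under the function role and rendered in the chatbot as a JSON code block. A sketch of the state after one tool-calling turn, with illustrative values:

query_messages = [{"role": "user", "content": "What is the weather in Berlin?"}]
# If extract() returned ("get_weather", '{"location": "Berlin"}'), the turn is
# recorded as a function message whose content is the normalized call:
output_messages = query_messages + [
    {"role": "function", "content": '{"name": "get_weather", "arguments": {"location": "Berlin"}}'}
]
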
src/llmtuner/webui/components/chatbot.py

@@ -17,7 +17,7 @@ def create_chat_box(
 ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
     with gr.Box(visible=visible) as chat_box:
         chatbot = gr.Chatbot()
-        history = gr.State([])
+        messages = gr.State([])
         with gr.Row():
             with gr.Column(scale=4):
                 system = gr.Textbox(show_label=False)
@@ -32,21 +32,21 @@ def create_chat_box(
                 top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
                 temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
 
-    tools.input(check_json_schema, [tools])
+    tools.input(check_json_schema, [tools, engine.manager.get_elem_by_name("top.lang")])
 
     submit_btn.click(
         engine.chatter.predict,
-        [chatbot, query, history, system, tools, max_new_tokens, top_p, temperature],
-        [chatbot, history],
+        [chatbot, query, messages, system, tools, max_new_tokens, top_p, temperature],
+        [chatbot, messages],
         show_progress=True,
     ).then(lambda: gr.update(value=""), outputs=[query])
 
-    clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
+    clear_btn.click(lambda: ([], []), outputs=[chatbot, messages], show_progress=True)
 
     return (
         chat_box,
         chatbot,
-        history,
+        messages,
         dict(
             system=system,
             tools=tools,
src/llmtuner/webui/locales.py

@@ -208,6 +208,8 @@ ALERTS = {
         "zh": "展示模式不支持训练,请先复制到私人空间。",
     },
     "err_device_count": {"en": "Multiple GPUs are not supported yet.", "zh": "尚不支持多 GPU 训练。"},
+    "err_tool_name": {"en": "Tool name not found.", "zh": "工具名称未找到。"},
+    "err_json_schema": {"en": "Invalid JSON schema.", "zh": "Json 格式错误。"},
     "info_aborting": {"en": "Aborted, wait for terminating...", "zh": "训练中断,正在等待线程结束……"},
     "info_aborted": {"en": "Ready.", "zh": "准备就绪。"},
     "info_finished": {"en": "Finished.", "zh": "训练完毕。"},
src/llmtuner/webui/utils.py

@@ -8,6 +8,7 @@ import gradio as gr
 from ..extras.packages import is_matplotlib_available
 from ..extras.ploting import smooth
 from .common import get_save_dir
+from .locales import ALERTS
 
 
 if TYPE_CHECKING:
@@ -40,11 +41,15 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
     return gr.update(interactive=True)
 
 
-def check_json_schema(text: str) -> None:
+def check_json_schema(text: str, lang: str) -> None:
     try:
-        json.loads(text)
+        tools = json.loads(text)
+        for tool in tools:
+            assert "name" in tool
+    except AssertionError:
+        gr.Warning(ALERTS["err_tool_name"][lang])
     except json.JSONDecodeError:
-        gr.Warning("Invalid JSON schema")
+        gr.Warning(ALERTS["err_json_schema"][lang])
 
 
 def gen_cmd(args: Dict[str, Any]) -> str:
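
check_json_schema now expects the tools textbox to hold a JSON list in which every entry has a "name", and reports failures through the localized ALERTS table. A few illustrative inputs; the warnings shown assume lang="en" and a live Gradio context for gr.Warning:

valid = '[{"name": "get_weather", "parameters": {"location": {"type": "string"}}}]'
missing_name = '[{"parameters": {}}]'
malformed = '[{"name": "get_weather"'

check_json_schema(valid, "en")         # no warning
check_json_schema(missing_name, "en")  # gr.Warning("Tool name not found.")
check_json_schema(malformed, "en")     # gr.Warning("Invalid JSON schema.")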