Former-commit-id: cadeac0f4475719454d18a8ed12f7573f115a2c4
hiyouga 2023-07-18 16:36:24 +08:00
parent 92c393dc2a
commit 77c4bb2985
3 changed files with 35 additions and 18 deletions

View File

@@ -10,6 +10,8 @@ from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.api.protocol import (
+    Role,
+    Finish,
     ModelCard,
     ModelList,
     ChatMessage,
@@ -49,12 +51,12 @@ def create_app():
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != "user":
+        if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content
         else:
             prefix = None
@@ -62,7 +64,7 @@ def create_app():
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
         if request.stream:
@@ -81,19 +83,19 @@ def create_app():
         choice_data = ChatCompletionResponseChoice(
             index=0,
-            message=ChatMessage(role="assistant", content=response),
-            finish_reason="stop"
+            message=ChatMessage(role=Role.ASSISTANT, content=response),
+            finish_reason=Finish.STOP
         )
-        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
+        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
     async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
-            delta=DeltaMessage(role="assistant"),
+            delta=DeltaMessage(role=Role.ASSISTANT),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
         for new_text in chat_model.stream_chat(
@@ -107,15 +109,15 @@ def create_app():
                 delta=DeltaMessage(content=new_text),
                 finish_reason=None
             )
-            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield json.dumps(chunk, ensure_ascii=False)
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(),
-            finish_reason="stop"
+            finish_reason=Finish.STOP
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
         yield "[DONE]"

View File

@@ -1,6 +1,18 @@
 import time
+from enum import Enum
 from pydantic import BaseModel, Field
-from typing import List, Literal, Optional
+from typing import List, Optional
+
+
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
 
 
 class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
+    role: Role
     content: str
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Role] = None
     content: Optional[str] = None
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Finish] = None
class ChatCompletionResponseUsage(BaseModel): class ChatCompletionResponseUsage(BaseModel):
@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
class ChatCompletionResponse(BaseModel): class ChatCompletionResponse(BaseModel):
id: Optional[str] = "chatcmpl-default" id: Optional[str] = "chatcmpl-default"
object: Literal["chat.completion"] object: Optional[str] = "chat.completion"
created: Optional[int] = Field(default_factory=lambda: int(time.time())) created: Optional[int] = Field(default_factory=lambda: int(time.time()))
model: str model: str
choices: List[ChatCompletionResponseChoice] choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
 class ChatCompletionStreamResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
+    object: Optional[str] = "chat.completion.chunk"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
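
Because object now has a default instead of being a required Literal field, the app code above can drop the explicit object= argument. A self-contained sketch of that effect, using abridged copies of the models from this diff and pydantic v1-style .json() (the model name and message contents are illustrative only):

import time
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel, Field

class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"

class Finish(str, Enum):
    STOP = "stop"
    LENGTH = "length"

class ChatMessage(BaseModel):
    role: Role
    content: str

class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Finish

class ChatCompletionResponse(BaseModel):
    id: Optional[str] = "chatcmpl-default"
    object: Optional[str] = "chat.completion"
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]

# No object= argument needed; the default supplies "chat.completion",
# and the str-based enums serialize to their plain string values.
resp = ChatCompletionResponse(
    model="demo-model",
    choices=[
        ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role=Role.ASSISTANT, content="Hello!"),
            finish_reason=Finish.STOP,
        )
    ],
)
print(resp.json(ensure_ascii=False))
# ... "object": "chat.completion", ... "role": "assistant", "finish_reason": "stop" ...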

View File

@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
         r"""
         Event called after logging the last logs.
         """
+        if not state.is_world_process_zero:
+            return
+
         cur_time = time.time()
         cur_steps = state.log_history[-1].get("step")
         elapsed_time = cur_time - self.start_time
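
The new guard uses TrainerState.is_world_process_zero, which is True only on the main process, so in distributed runs the progress log is written once rather than once per worker. A minimal sketch of the same pattern in a hypothetical callback (not the repository's full LogCallback):

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class MainProcessOnlyLogger(TrainerCallback):
    """Hypothetical callback illustrating the world-process-zero guard."""

    def on_log(self, args: TrainingArguments, state: TrainerState,
               control: TrainerControl, **kwargs):
        # Skip non-main ranks to avoid duplicated logs under DDP / multi-process training.
        if not state.is_world_process_zero:
            return
        if state.log_history:
            print("step:", state.log_history[-1].get("step"),
                  "loss:", state.log_history[-1].get("loss"))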