Former-commit-id: 2ae3445b0d28b4ed22ddbb2cfe09089ae0c23fe1
hiyouga 2023-07-18 16:36:24 +08:00
parent c85a6b83b3
commit ec166e736a
3 changed files with 35 additions and 18 deletions

View File

@@ -10,6 +10,8 @@ from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.api.protocol import (
+    Role,
+    Finish,
     ModelCard,
     ModelList,
     ChatMessage,
@@ -49,12 +51,12 @@ def create_app():
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != "user":
+        if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
         query = request.messages[-1].content

         prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content
         else:
             prefix = None
@@ -62,7 +64,7 @@ def create_app():
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])

         if request.stream:
@@ -81,19 +83,19 @@ def create_app():
         choice_data = ChatCompletionResponseChoice(
             index=0,
-            message=ChatMessage(role="assistant", content=response),
-            finish_reason="stop"
+            message=ChatMessage(role=Role.ASSISTANT, content=response),
+            finish_reason=Finish.STOP
         )

-        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
+        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)

     async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
-            delta=DeltaMessage(role="assistant"),
+            delta=DeltaMessage(role=Role.ASSISTANT),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)

         for new_text in chat_model.stream_chat(
@@ -107,15 +109,15 @@ def create_app():
                 delta=DeltaMessage(content=new_text),
                 finish_reason=None
             )
-            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield json.dumps(chunk, ensure_ascii=False)

         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(),
-            finish_reason="stop"
+            finish_reason=Finish.STOP
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
         yield "[DONE]"

View File

@@ -1,6 +1,18 @@
 import time
+from enum import Enum
 from pydantic import BaseModel, Field
-from typing import List, Literal, Optional
+from typing import List, Optional
+
+
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"


 class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
+    role: Role
     content: str


 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Role] = None
     content: Optional[str] = None
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Finish


 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Finish] = None


 class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
 class ChatCompletionResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion"]
+    object: Optional[str] = "chat.completion"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
 class ChatCompletionStreamResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
+    object: Optional[str] = "chat.completion.chunk"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
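
The str-mixin enums are what keep this change backward compatible: because Role and Finish subclass str, their members compare equal to the raw strings the API previously used, and pydantic serializes them as plain strings. A standalone sketch of that behavior (standard-library Python only, independent of this repo):

from enum import Enum

class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"

# Members of a str-mixin Enum compare equal to their string values,
# so callers that still pass "user" or "assistant" keep working.
assert Role.USER == "user"
assert "assistant" == Role.ASSISTANT
print(Role.SYSTEM.value)  # -> "system"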

View File

@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
         r"""
         Event called after logging the last logs.
         """
+        if not state.is_world_process_zero:
+            return
+
         cur_time = time.time()
         cur_steps = state.log_history[-1].get("step")
         elapsed_time = cur_time - self.start_time
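
The added guard is the usual transformers pattern for confining side effects to the main process in distributed runs: state.is_world_process_zero is True on exactly one rank, so every other rank returns early instead of emitting duplicate log records. A minimal standalone sketch of the same pattern (an illustrative callback, not the repo's full LogCallback):

from transformers import TrainerCallback

class RankZeroLogCallback(TrainerCallback):
    """Logs only from the main (rank-zero) process."""

    def on_log(self, args, state, control, **kwargs):
        # Skip every process except world-process zero to avoid
        # duplicated log lines under DDP / multi-GPU training.
        if not state.is_world_process_zero:
            return
        if state.log_history:
            print(state.log_history[-1])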