Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
Parent: 92c393dc2a
Commit: 77c4bb2985
@@ -10,6 +10,8 @@ from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.api.protocol import (
+    Role,
+    Finish,
     ModelCard,
     ModelList,
     ChatMessage,
@@ -49,12 +51,12 @@ def create_app():
 
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != "user":
+        if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
         query = request.messages[-1].content
 
         prev_messages = request.messages[:-1]
-        if len(prev_messages) > 0 and prev_messages[0].role == "system":
+        if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content
         else:
             prefix = None
@@ -62,7 +64,7 @@ def create_app():
         history = []
         if len(prev_messages) % 2 == 0:
             for i in range(0, len(prev_messages), 2):
-                if prev_messages[i].role == "user" and prev_messages[i+1].role == "assistant":
+                if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
                     history.append([prev_messages[i].content, prev_messages[i+1].content])
 
         if request.stream:
@@ -81,19 +83,19 @@ def create_app():
 
         choice_data = ChatCompletionResponseChoice(
             index=0,
-            message=ChatMessage(role="assistant", content=response),
-            finish_reason="stop"
+            message=ChatMessage(role=Role.ASSISTANT, content=response),
+            finish_reason=Finish.STOP
         )
 
-        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage, object="chat.completion")
+        return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
 
     async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
-            delta=DeltaMessage(role="assistant"),
+            delta=DeltaMessage(role=Role.ASSISTANT),
             finish_reason=None
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
 
         for new_text in chat_model.stream_chat(
@@ -107,15 +109,15 @@ def create_app():
                 delta=DeltaMessage(content=new_text),
                 finish_reason=None
             )
-            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+            chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
             yield json.dumps(chunk, ensure_ascii=False)
 
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(),
-            finish_reason="stop"
+            finish_reason=Finish.STOP
         )
-        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data], object="chat.completion.chunk")
+        chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
         yield json.dumps(chunk, ensure_ascii=False)
         yield "[DONE]"
 
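Swapping the literal strings for Role and Finish members is behavior-preserving in the comparisons above because the enums mix in str, so members compare equal to the raw strings that clients send in their JSON payloads. A minimal sketch of that property, re-declaring Role locally rather than importing it from llmtuner.api.protocol:

# Illustrative sketch only, not part of this commit.
from enum import Enum


class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


assert Role.USER == "user"             # str mixin: equal to the literal value
assert "assistant" == Role.ASSISTANT   # comparison works in either direction
assert Role("system") is Role.SYSTEM   # round-trips from an incoming string
print(Role.USER.value)                 # -> "user"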
@@ -1,6 +1,18 @@
 import time
+from enum import Enum
 from pydantic import BaseModel, Field
-from typing import List, Literal, Optional
+from typing import List, Optional
 
 
+class Role(str, Enum):
+    USER = "user"
+    ASSISTANT = "assistant"
+    SYSTEM = "system"
+
+
+class Finish(str, Enum):
+    STOP = "stop"
+    LENGTH = "length"
+
+
 class ModelCard(BaseModel):
@@ -19,12 +31,12 @@ class ModelList(BaseModel):
 
 
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system"]
+    role: Role
     content: str
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Role] = None
     content: Optional[str] = None
 
 
@@ -41,13 +53,13 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int
     message: ChatMessage
-    finish_reason: Literal["stop", "length"]
+    finish_reason: Finish
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]] = None
+    finish_reason: Optional[Finish] = None
 
 
 class ChatCompletionResponseUsage(BaseModel):
@@ -58,7 +70,7 @@ class ChatCompletionResponseUsage(BaseModel):
 
 class ChatCompletionResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion"]
+    object: Optional[str] = "chat.completion"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseChoice]
@@ -67,7 +79,7 @@ class ChatCompletionResponse(BaseModel):
 
 class ChatCompletionStreamResponse(BaseModel):
     id: Optional[str] = "chatcmpl-default"
-    object: Literal["chat.completion.chunk"]
+    object: Optional[str] = "chat.completion.chunk"
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
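Two effects of these model changes are worth spelling out: pydantic coerces incoming role strings into the Role enum, and giving object a default is what lets the handlers above drop the explicit object= keyword. A small self-contained sketch (the Demo classes below are re-declared here for illustration and are not the project's actual models):

# Hedged sketch only; mirrors the field shapes shown in the diff above.
from enum import Enum
from typing import Optional
from pydantic import BaseModel


class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"


class ChatMessageDemo(BaseModel):
    role: Role
    content: str


class ChatCompletionResponseDemo(BaseModel):
    object: Optional[str] = "chat.completion"
    model: str


msg = ChatMessageDemo(role="assistant", content="hi")   # plain string coerced to Role
resp = ChatCompletionResponseDemo(model="demo-model")   # object filled by its default
print(msg.role is Role.ASSISTANT, resp.object)
# -> True chat.completion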
@@ -47,6 +47,9 @@ class LogCallback(TrainerCallback):
         r"""
         Event called after logging the last logs.
         """
+        if not state.is_world_process_zero:
+            return
+
         cur_time = time.time()
         cur_steps = state.log_history[-1].get("step")
         elapsed_time = cur_time - self.start_time
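The three added lines turn the log handler into a no-op on every rank except the main process, so multi-GPU (DDP) runs do not write duplicate progress records. A minimal sketch of the same guard in a standalone callback, assuming the standard transformers TrainerCallback API and not the project's actual LogCallback:

# Hedged sketch only; RankZeroLogCallback is a hypothetical example class.
from transformers import TrainerCallback


class RankZeroLogCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:  # skip non-main ranks entirely
            return
        if state.log_history:
            last = state.log_history[-1]
            print("step", last.get("step"), last)  # only rank 0 reaches this point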