diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 41a7fe9a..fdde591c 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -68,7 +68,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             return EventSourceResponse(generate, media_type="text/event-stream")
 
         response, (prompt_length, response_length) = chat_model.chat(
-            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system,
+            do_sample=request.do_sample,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_new_tokens=request.max_tokens
         )
 
         usage = ChatCompletionResponseUsage(
@@ -95,7 +99,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             yield chunk.json(exclude_unset=True, ensure_ascii=False)
 
         for new_text in chat_model.stream_chat(
-            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system,
+            do_sample=request.do_sample,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_new_tokens=request.max_tokens
         ):
             if len(new_text) == 0:
                 continue
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index cba0b6a6..1412af5f 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -43,6 +43,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
+    do_sample: Optional[bool] = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     n: Optional[int] = 1
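
Below is a minimal client-side sketch of how the new request field could be exercised against a patched server. The host, port, and route are assumptions: the OpenAI-style /v1/chat/completions path is suggested by the protocol class names but does not appear in this diff. With Hugging Face-style generation, do_sample=False typically selects greedy decoding, in which case temperature and top_p no longer influence the output.

# Hypothetical client call exercising the new do_sample field.
# Assumptions: the server listens on localhost:8000 and exposes an
# OpenAI-compatible /v1/chat/completions route (neither is shown
# in this diff).
import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello!"}],
    # New field from this patch; defaults to True when omitted.
    # False typically means greedy decoding, so temperature and
    # top_p should have no effect on the generated text.
    "do_sample": False,
    "max_tokens": 128,
}

response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])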