From 4a0aab86f1fb469e114dcc23d21b6074106e10d0 Mon Sep 17 00:00:00 2001
From: zhaonx <953608703@qq.com>
Date: Tue, 30 Apr 2024 17:17:09 +0800
Subject: [PATCH 1/5] "add support for vllm api stop parameter"

Former-commit-id: 42edc81585bc7170ab4e4871ad12094079e89bc9
---
 src/llmtuner/api/app.py                 | 2 ++
 src/llmtuner/api/protocol.py            | 3 ++-
 src/llmtuner/chat/vllm_engine.py        | 3 +++
 src/llmtuner/hparams/generating_args.py | 7 +++++--
 4 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 3f06fef1..892bf901 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -141,6 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
             num_return_sequences=request.n,
+            stop=request.stop
         )

         prompt_length, response_length = 0, 0
@@ -193,6 +194,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             temperature=request.temperature,
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
+            stop=request.stop
         ):
             if len(new_token) == 0:
                 continue
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index ece2132b..8f1b7b4c 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None


 class ChatCompletionResponseChoice(BaseModel):
diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index 0f0dc366..9863d635 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)

         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -105,6 +106,7 @@ class VllmEngine(BaseEngine):
                 top_k=top_k or generating_args["top_k"],
                 num_return_sequences=num_return_sequences or 1,
                 repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                stop=stop or generating_args["stop"]
             )
         )

@@ -125,6 +127,7 @@ class VllmEngine(BaseEngine):
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
+            stop=generating_args["stop"],
         )

         if self.processor is not None and image is not None:
diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py
index e792c003..03e760e7 100644
--- a/src/llmtuner/hparams/generating_args.py
+++ b/src/llmtuner/hparams/generating_args.py
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Union, Optional, List


 @dataclass
@@ -46,7 +46,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-
+    stop: Union[Optional[str], List[str]] = field(
+        default=None,
+        metadata={"help": "List of strings or string that stop the generation when they are generated. The returned output will not contain the stop strings."},
+    )
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
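With PATCH 1 applied, the OpenAI-compatible endpoint accepts a per-request `stop` value, either a single string or a list of strings. A minimal client-side sketch of exercising the new field; the host, port, and model name are illustrative assumptions, not values from the patch:

```python
# Sketch of a request that exercises the new `stop` field.
# The URL and model name below are hypothetical.
import requests

payload = {
    "model": "test-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Count from 1 to 10."}],
    "max_tokens": 64,
    "stop": ["5", "five"],  # str or List[str], per ChatCompletionRequest.stop
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
# vLLM excludes the stop string from the output, so the reply should end before "5".
print(resp.json()["choices"][0]["message"]["content"])
```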
From 189346188be36b15c5043d3fa35d3c16b8f263bc Mon Sep 17 00:00:00 2001
From: zhaonx96 <953608703@qq.com>
Date: Mon, 6 May 2024 10:10:00 +0800
Subject: [PATCH 2/5] "add stop parameter in chat.py"

Former-commit-id: 80645751bc4db20dbadb53950fe35af8b67eec41
---
 src/llmtuner/api/chat.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/llmtuner/api/chat.py b/src/llmtuner/api/chat.py
index fa2f0d03..972ee906 100644
--- a/src/llmtuner/api/chat.py
+++ b/src/llmtuner/api/chat.py
@@ -103,6 +103,7 @@ async def create_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
+        stop=request.stop
     )

     prompt_length, response_length = 0, 0
@@ -155,6 +156,7 @@ async def create_stream_chat_completion_response(
         temperature=request.temperature,
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
+        stop=request.stop
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
From d65b2332cf5fa95be8a03ee9e4362c6e116a34f1 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 7 May 2024 00:27:56 +0800
Subject: [PATCH 3/5] Update generating_args.py

Former-commit-id: 7ae7ae64f0e9e8661f9efd30997f8b96673d467a
---
 src/llmtuner/hparams/generating_args.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py
index 03e760e7..e3e196e9 100644
--- a/src/llmtuner/hparams/generating_args.py
+++ b/src/llmtuner/hparams/generating_args.py
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, Union, Optional, List
+from typing import Any, Dict


 @dataclass
@@ -46,10 +46,6 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    stop: Union[Optional[str], List[str]] = field(
-        default=None,
-        metadata={"help": "List of strings or string that stop the generation when they are generated. The returned output will not contain the stop strings."},
-    )
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
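PATCH 3 removes the `stop` field from `GeneratingArguments`, keeping it request-scoped: the value travels from `ChatCompletionRequest` into the engine per call rather than living on as a persistent generation hyperparameter. The request model still accepts both shapes. A small sketch of that behavior, with the class reduced to the one field for illustration (the real model defines many more):

```python
# Sketch verifying that the annotation used in the patch accepts a single
# string, a list of strings, or nothing at all.
from typing import List, Optional, Union

from pydantic import BaseModel


class ChatCompletionRequest(BaseModel):
    stop: Union[Optional[str], List[str]] = None  # annotation as in the patch


print(ChatCompletionRequest(stop="###").stop)          # '###'
print(ChatCompletionRequest(stop=["###", "\n"]).stop)  # ['###', '\n']
print(ChatCompletionRequest().stop)                    # None
```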
From 4c911044717d9b4a599d971713b37adf07dc45d2 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 7 May 2024 00:28:16 +0800
Subject: [PATCH 4/5] Update generating_args.py

Former-commit-id: f32eefae3d20bb8482704daf2a0c5743452f2ce7
---
 src/llmtuner/hparams/generating_args.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py
index e3e196e9..e792c003 100644
--- a/src/llmtuner/hparams/generating_args.py
+++ b/src/llmtuner/hparams/generating_args.py
@@ -46,6 +46,7 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
+
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
From df66b288a2815a839963f71dc5ffe5a95a80cb8d Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 7 May 2024 00:37:05 +0800
Subject: [PATCH 5/5] Update vllm_engine.py

Former-commit-id: 17d0005b8cb9cf75b8247bcdf4ce022e1a5afd0b
---
 src/llmtuner/chat/vllm_engine.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index 9863d635..d50e41aa 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -106,7 +106,6 @@ class VllmEngine(BaseEngine):
                 top_k=top_k or generating_args["top_k"],
                 num_return_sequences=num_return_sequences or 1,
                 repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
-                stop=stop or generating_args["stop"]
             )
         )

@@ -124,10 +123,10 @@ class VllmEngine(BaseEngine):
             top_k=generating_args["top_k"],
             use_beam_search=generating_args["num_beams"] > 1,
             length_penalty=generating_args["length_penalty"],
+            stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
-            stop=generating_args["stop"],
        )

         if self.processor is not None and image is not None:
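After PATCH 5, the request's `stop` value no longer round-trips through `generating_args` at all; the popped local is handed straight to vLLM's `SamplingParams`. A condensed sketch of the resulting call, with illustrative values standing in for the engine's real state:

```python
# Condensed sketch of the final wiring: `stop` comes straight from the
# request (None, a str, or a list of str) and feeds SamplingParams directly.
# All concrete values here are illustrative, not taken from the patch.
from vllm import SamplingParams

stop = ["###", "\n\nUser:"]  # popped from input_kwargs in the real engine

sampling_params = SamplingParams(
    n=1,
    temperature=0.7,
    top_p=0.9,
    stop=stop,  # request-scoped stop strings, as wired in PATCH 5
    max_tokens=512,
    skip_special_tokens=True,
)
```

Passing `stop` directly also means an empty or absent request value simply disables stop-string matching for that call, with no default leaking in from the tuning configuration.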