diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py
index 3f06fef1..892bf901 100644
--- a/src/llmtuner/api/app.py
+++ b/src/llmtuner/api/app.py
@@ -141,6 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
             num_return_sequences=request.n,
+            stop=request.stop
         )
 
         prompt_length, response_length = 0, 0
@@ -193,6 +194,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             temperature=request.temperature,
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
+            stop=request.stop
         ):
             if len(new_token) == 0:
                 continue
diff --git a/src/llmtuner/api/protocol.py b/src/llmtuner/api/protocol.py
index ece2132b..8f1b7b4c 100644
--- a/src/llmtuner/api/protocol.py
+++ b/src/llmtuner/api/protocol.py
@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index 0f0dc366..9863d635 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)
 
         generating_args = self.generating_args.copy()
         generating_args.update(
@@ -105,6 +106,7 @@ class VllmEngine(BaseEngine):
                 top_p=top_p or generating_args["top_p"],
                 top_k=top_k or generating_args["top_k"],
                 num_return_sequences=num_return_sequences or 1,
                 repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                stop=stop or generating_args["stop"]
             )
         )
@@ -125,6 +127,7 @@ class VllmEngine(BaseEngine):
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
+            stop=generating_args["stop"],
         )
 
         if self.processor is not None and image is not None:
diff --git a/src/llmtuner/hparams/generating_args.py b/src/llmtuner/hparams/generating_args.py
index e792c003..03e760e7 100644
--- a/src/llmtuner/hparams/generating_args.py
+++ b/src/llmtuner/hparams/generating_args.py
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Union, Optional, List
 
 
 @dataclass
@@ -46,7 +46,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-
+    stop: Union[Optional[str], List[str]] = field(
+        default=None,
+        metadata={"help": "List of strings or string that stop the generation when they are generated. The returned output will not contain the stop strings."},
+    )
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:
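
For context, a minimal usage sketch of the new "stop" request field against the OpenAI-style /v1/chat/completions route served by create_app. The base URL and model name below are illustrative assumptions, not part of this change, and of the two backends only the vLLM engine in this diff forwards "stop" into SamplingParams (which accepts either a single string or a list of strings):

    # Hypothetical client call; adjust the base URL and model name to wherever
    # the llmtuner API server is actually running.
    import requests

    payload = {
        "model": "test-model",  # placeholder model name (assumption)
        "messages": [{"role": "user", "content": "Count from 1 to 10."}],
        "max_tokens": 64,
        # Either a single string or a list of strings, matching ChatCompletionRequest.stop
        "stop": ["7", "\n\n"],
    }

    response = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=60)
    print(response.json()["choices"][0]["message"]["content"])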