mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 22:02:51 +08:00

"add support for vllm api stop parameter"
Former-commit-id: 42edc81585bc7170ab4e4871ad12094079e89bc9
parent 35917001b1
commit 4a0aab86f1
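The diff below threads an OpenAI-style stop parameter from the HTTP request (ChatCompletionRequest) through the chat model call and GeneratingArguments into vLLM's sampling parameters. As a rough sketch of what the change enables on the client side, assuming the server exposes an OpenAI-compatible /v1/chat/completions route on localhost:8000 (endpoint path, port, and model name are assumptions, not part of this commit):

# Sketch of a client request exercising the new stop field; base URL, port,
# and model name are placeholders, not taken from this commit.
import requests

payload = {
    "model": "test-model",
    "messages": [{"role": "user", "content": "Count from 1 to 10."}],
    "max_tokens": 64,
    # stop may be a single string or a list of strings; generation halts when
    # any of them is produced, and the stop text is not included in the output.
    "stop": ["\n\n", "7"],
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
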
@@ -141,6 +141,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
             num_return_sequences=request.n,
+            stop=request.stop
         )

         prompt_length, response_length = 0, 0

@@ -193,6 +194,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             temperature=request.temperature,
             top_p=request.top_p,
             max_new_tokens=request.max_tokens,
+            stop=request.stop
         ):
             if len(new_token) == 0:
                 continue

@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Literal

@@ -78,6 +78,7 @@ class ChatCompletionRequest(BaseModel):
     n: int = 1
     max_tokens: Optional[int] = None
     stream: bool = False
+    stop: Union[Optional[str], List[str]] = None


 class ChatCompletionResponseChoice(BaseModel):

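Because stop is typed as Union[Optional[str], List[str]], pydantic accepts a single string, a list of strings, or nothing at all. A minimal sketch of that validation behavior, using an abbreviated stand-in for the model above (required fields such as the message list are omitted here):

# Abbreviated sketch of ChatCompletionRequest; only the fields shown in the
# hunk above are reproduced, so this is not the full request model.
from typing import List, Optional, Union
from pydantic import BaseModel

class ChatCompletionRequestSketch(BaseModel):
    n: int = 1
    max_tokens: Optional[int] = None
    stream: bool = False
    stop: Union[Optional[str], List[str]] = None

print(ChatCompletionRequestSketch().stop)                      # None (field is optional)
print(ChatCompletionRequestSketch(stop="###").stop)            # '###'
print(ChatCompletionRequestSketch(stop=["###", "\n\n"]).stop)  # ['###', '\n\n']
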
@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
         repetition_penalty = input_kwargs.pop("repetition_penalty", None)
         max_length = input_kwargs.pop("max_length", None)
         max_new_tokens = input_kwargs.pop("max_new_tokens", None)
+        stop = input_kwargs.pop("stop", None)

         generating_args = self.generating_args.copy()
         generating_args.update(

@@ -105,6 +106,7 @@ class VllmEngine(BaseEngine):
                 top_k=top_k or generating_args["top_k"],
                 num_return_sequences=num_return_sequences or 1,
                 repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
+                stop=stop or generating_args["stop"]
             )
         )

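In the update above, stop=stop or generating_args["stop"] means a stop value supplied with the request takes precedence, and the configured default is used only when the request leaves it unset (or otherwise falsy). A tiny illustration of that fallback, with a made-up default:

# Illustration of the "request value or configured default" fallback used above.
generating_args = {"stop": ["</s>"]}  # hypothetical configured default

def resolve_stop(request_stop):
    # Mirrors stop=stop or generating_args["stop"] from the diff.
    return request_stop or generating_args["stop"]

print(resolve_stop(None))     # ['</s>'] -> falls back to the default
print(resolve_stop(["###"]))  # ['###']  -> per-request value wins
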
@@ -125,6 +127,7 @@ class VllmEngine(BaseEngine):
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=generating_args["max_new_tokens"],
             skip_special_tokens=True,
+            stop=generating_args["stop"],
         )

         if self.processor is not None and image is not None:

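For context, vLLM's SamplingParams already accepts stop as a string or a list of strings, which is why the popped value can be forwarded unchanged. A standalone sketch of the SamplingParams construction (requires vLLM; the values are placeholders rather than the engine's actual defaults):

# Standalone sketch: how a stop list reaches vLLM's SamplingParams.
from vllm import SamplingParams

stop = ["</s>", "\n\nUser:"]  # placeholder stop strings
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=128,
    skip_special_tokens=True,
    stop=stop,  # generation ends when any of these strings is produced
)
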
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Union, Optional, List


 @dataclass

@@ -46,7 +46,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-
+    stop: Union[Optional[str], List[str]] = field(
+        default=None,
+        metadata={"help": "List of strings or string that stop the generation when they are generated. The returned output will not contain the stop strings."},
+    )
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
         if args.get("max_new_tokens", -1) > 0:

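On the dataclass side, the new field behaves like any other GeneratingArguments entry and survives the to_dict round trip. A short sketch with the other fields omitted (max_new_tokens appears only to show the asdict output; its value here is a placeholder):

# Abbreviated sketch of GeneratingArguments; only the new stop field and a
# placeholder max_new_tokens are shown, not the full set of fields.
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class GeneratingArgumentsSketch:
    max_new_tokens: int = 512  # placeholder value for illustration
    stop: Union[Optional[str], List[str]] = field(
        default=None,
        metadata={"help": "String or list of strings that stop generation when produced."},
    )

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

args = GeneratingArgumentsSketch(stop=["###"])
print(args.to_dict())  # {'max_new_tokens': 512, 'stop': ['###']}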