diff --git a/src/api_demo.py b/src/api_demo.py
index b05157a1..85c41eae 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -49,6 +49,7 @@ class ChatCompletionRequest(BaseModel):
     messages: List[ChatMessage]
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    max_length: Optional[int] = None
     max_new_tokens: Optional[int] = None
     stream: Optional[bool] = False
 
@@ -100,9 +101,14 @@ async def create_chat_completion(request: ChatCompletionRequest):
         "input_ids": inputs["input_ids"],
         "temperature": request.temperature if request.temperature else gen_kwargs["temperature"],
         "top_p": request.top_p if request.top_p else gen_kwargs["top_p"],
-        "max_new_tokens": request.max_new_tokens if request.max_new_tokens else gen_kwargs["max_new_tokens"],
         "logits_processor": get_logits_processor()
     })
+    if request.max_length:
+        gen_kwargs.pop("max_new_tokens", None)
+        gen_kwargs["max_length"] = request.max_length
+    if request.max_new_tokens:
+        gen_kwargs.pop("max_length", None)
+        gen_kwargs["max_new_tokens"] = request.max_new_tokens
 
     if request.stream:
         generate = predict(gen_kwargs, request.model)
diff --git a/src/utils/common.py b/src/utils/common.py
index 35d17c5b..93c9a2ce 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -171,8 +171,8 @@ def load_pretrained(
         padding_side="left",
         **config_kwargs
     )
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
+    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
+        tokenizer.pad_token_id = 0 # set as the <unk> token
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
diff --git a/src/utils/config.py b/src/utils/config.py
index dfe09392..ffd4ee7a 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -277,6 +277,10 @@ class GeneratingArguments:
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
+    )
     max_new_tokens: Optional[int] = field(
         default=512,
         metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
@@ -291,4 +295,7 @@
     )
 
     def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
+        args = asdict(self)
+        if args.get("max_new_tokens", None):
+            args.pop("max_length", None)
+        return args
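
Note for reviewers: because the `max_new_tokens` branch in api_demo.py runs after the `max_length` branch, `max_new_tokens` takes precedence whenever a request sets both fields, and each branch evicts the competing key so `generate()` never receives both. A minimal, self-contained sketch of that resolution order follows; the `resolve_length_args` helper and the sample values are illustrative only, not part of this patch.

from typing import Optional


def resolve_length_args(gen_kwargs: dict,
                        max_length: Optional[int],
                        max_new_tokens: Optional[int]) -> dict:
    # Mirrors the two `if` blocks added in api_demo.py: each request field,
    # when truthy, removes the competing key before setting its own.
    if max_length:
        gen_kwargs.pop("max_new_tokens", None)
        gen_kwargs["max_length"] = max_length
    if max_new_tokens:  # runs last, so it wins when both fields are set
        gen_kwargs.pop("max_length", None)
        gen_kwargs["max_new_tokens"] = max_new_tokens
    return gen_kwargs


# The starting dict stands in for GeneratingArguments.to_dict(), which
# already drops max_length whenever max_new_tokens is truthy.
assert resolve_length_args({"max_new_tokens": 512}, 1024, None) == {"max_length": 1024}
assert resolve_length_args({"max_new_tokens": 512}, 1024, 256) == {"max_new_tokens": 256}
assert resolve_length_args({"max_new_tokens": 512}, None, None) == {"max_new_tokens": 512}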