fix chat engines

do not use pop(key, default) since api assigns None to dict values


Former-commit-id: d52fae2fa866afeb6156dc98388ce5cc6d5eca77
This commit is contained in:
hiyouga 2024-05-20 00:36:43 +08:00
parent 6955042c10
commit 864da49139
3 changed files with 50 additions and 37 deletions

View File

@ -2,7 +2,7 @@ import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
import torch
from transformers import GenerationConfig, TextIteratorStreamer
@ -66,16 +66,16 @@ class HuggingfaceEngine(BaseEngine):
prompt_length = len(prompt_ids)
inputs = torch.tensor([prompt_ids], device=model.device)
do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
temperature = input_kwargs.pop("temperature", generating_args["temperature"])
top_p = input_kwargs.pop("top_p", generating_args["top_p"])
top_k = input_kwargs.pop("top_k", generating_args["top_k"])
num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
stop = input_kwargs.pop("stop", None)
do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None:
raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
@ -83,20 +83,23 @@ class HuggingfaceEngine(BaseEngine):
generating_args = generating_args.copy()
generating_args.update(
dict(
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
temperature=temperature if temperature is not None else generating_args["temperature"],
top_p=top_p if top_p is not None else generating_args["top_p"],
top_k=top_k if top_k is not None else generating_args["top_k"],
num_return_sequences=num_return_sequences,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty,
repetition_penalty=repetition_penalty
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
pad_token_id=tokenizer.pad_token_id,
)
)
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
if isinstance(num_return_sequences, int) and num_return_sequences > 1: # do_sample needs temperature > 0
generating_args["do_sample"] = True
generating_args["temperature"] = generating_args["temperature"] or 1.0
if not generating_args["temperature"]:
generating_args["do_sample"] = False

View File

@ -1,5 +1,5 @@
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
@ -102,18 +102,25 @@ class VllmEngine(BaseEngine):
)
prompt_length = len(prompt_ids)
use_beam_search = self.generating_args["num_beams"] > 1
temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
max_length = input_kwargs.pop("max_length", None)
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
stop = input_kwargs.pop("stop", None)
use_beam_search: bool = self.generating_args["num_beams"] > 1
temperature: Optional[float] = input_kwargs.pop("temperature", None)
top_p: Optional[float] = input_kwargs.pop("top_p", None)
top_k: Optional[float] = input_kwargs.pop("top_k", None)
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if "max_new_tokens" in self.generating_args:
max_tokens = self.generating_args["max_new_tokens"]
elif "max_length" in self.generating_args:
if self.generating_args["max_length"] > prompt_length:
max_tokens = self.generating_args["max_length"] - prompt_length
else:
max_tokens = 1
max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"]
if max_length:
max_tokens = max_length - prompt_length if max_length > prompt_length else 1
@ -122,12 +129,15 @@ class VllmEngine(BaseEngine):
sampling_params = SamplingParams(
n=num_return_sequences,
repetition_penalty=repetition_penalty,
temperature=temperature,
top_p=top_p,
top_k=top_k,
repetition_penalty=(
repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
)
or 1.0, # repetition_penalty must > 0
temperature=temperature if temperature is not None else self.generating_args["temperature"],
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
top_k=top_k if top_k is not None else self.generating_args["top_k"],
use_beam_search=use_beam_search,
length_penalty=length_penalty,
length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
stop=stop,
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
max_tokens=max_tokens,

View File

@ -68,8 +68,8 @@ class Template:
self,
tokenizer: "PreTrainedTokenizer",
messages: List[Dict[str, str]],
system: str,
tools: str,
system: Optional[str],
tools: Optional[str],
cutoff_len: int,
reserved_label_len: int,
) -> Sequence[Tuple[List[int], List[int]]]: