fix chat engines

do not use pop(key, default): the API layer assigns None to dict values, so pop() returns that None and the default is never applied
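
A minimal sketch of the pitfall (illustrative values, not taken from the patch): when the API layer has already stored None under a key, pop(key, default) returns that None and the default is silently skipped.

    # why pop(key, default) is not enough: the key exists, but holds None
    request_kwargs = {"temperature": None}  # API layer filled the field with None
    temperature = request_kwargs.pop("temperature", 0.7)
    assert temperature is None  # the 0.7 default was never applied

    # the pattern this commit adopts: pop to None, then fall back explicitly
    request_kwargs = {"temperature": None}
    temperature = request_kwargs.pop("temperature", None)
    temperature = temperature if temperature is not None else 0.7
    assert temperature == 0.7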


Former-commit-id: d52fae2fa866afeb6156dc98388ce5cc6d5eca77
hiyouga 2024-05-20 00:36:43 +08:00
parent 6955042c10
commit 864da49139
3 changed files with 50 additions and 37 deletions

src/llamafactory/chat/hf_engine.py

@@ -2,7 +2,7 @@ import asyncio
 import concurrent.futures
 import os
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
@@ -66,16 +66,16 @@ class HuggingfaceEngine(BaseEngine):
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
 
-        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
-        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
-        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
-        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
-        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
-        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        stop = input_kwargs.pop("stop", None)
+        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
         if stop is not None:
             raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
@@ -83,20 +83,23 @@ class HuggingfaceEngine(BaseEngine):
         generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature if temperature is not None else generating_args["temperature"],
+                top_p=top_p if top_p is not None else generating_args["top_p"],
+                top_k=top_k if top_k is not None else generating_args["top_k"],
                 num_return_sequences=num_return_sequences,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
+                repetition_penalty=repetition_penalty
+                if repetition_penalty is not None
+                else generating_args["repetition_penalty"],
+                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
         )
 
-        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
+        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
             generating_args["do_sample"] = True
+            generating_args["temperature"] = generating_args["temperature"] or 1.0
 
         if not generating_args["temperature"]:
             generating_args["do_sample"] = False

src/llamafactory/chat/vllm_engine.py

@@ -1,5 +1,5 @@
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.logging import get_logger
@@ -102,18 +102,25 @@ class VllmEngine(BaseEngine):
         )
         prompt_length = len(prompt_ids)
 
-        use_beam_search = self.generating_args["num_beams"] > 1
-        temperature = input_kwargs.pop("temperature", self.generating_args["temperature"])
-        top_p = input_kwargs.pop("top_p", self.generating_args["top_p"])
-        top_k = input_kwargs.pop("top_k", self.generating_args["top_k"])
-        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"])
-        length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"])
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        stop = input_kwargs.pop("stop", None)
+        use_beam_search: bool = self.generating_args["num_beams"] > 1
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+
+        if "max_new_tokens" in self.generating_args:
+            max_tokens = self.generating_args["max_new_tokens"]
+        elif "max_length" in self.generating_args:
+            if self.generating_args["max_length"] > prompt_length:
+                max_tokens = self.generating_args["max_length"] - prompt_length
+            else:
+                max_tokens = 1
 
-        max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"]
         if max_length:
             max_tokens = max_length - prompt_length if max_length > prompt_length else 1
@@ -122,12 +129,15 @@ class VllmEngine(BaseEngine):
         sampling_params = SamplingParams(
             n=num_return_sequences,
-            repetition_penalty=repetition_penalty,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
+            repetition_penalty=(
+                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
+            )
+            or 1.0,  # repetition_penalty must > 0
+            temperature=temperature if temperature is not None else self.generating_args["temperature"],
+            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
+            top_k=top_k if top_k is not None else self.generating_args["top_k"],
             use_beam_search=use_beam_search,
-            length_penalty=length_penalty,
+            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
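
Two things stand out in the VllmEngine hunks. First, max_tokens is resolved with an explicit precedence: a request-level max_length wins, else the config's max_new_tokens, else the config's max_length minus the prompt length, floored at 1. Second, repetition_penalty and top_p fall back to 1.0 when they resolve to 0, per the "must > 0" comments in the diff. A sketch of the resolution order, with resolve_max_tokens as a hypothetical helper name (the fallback when neither default is set is my assumption, not shown in the hunk):

    # Hypothetical helper mirroring the max_tokens resolution above.
    from typing import Any, Dict, Optional

    def resolve_max_tokens(defaults: Dict[str, Any], prompt_length: int, max_length: Optional[int] = None) -> int:
        max_tokens = 1  # assumption: floor of 1 when neither default is configured
        if "max_new_tokens" in defaults:
            max_tokens = defaults["max_new_tokens"]
        elif "max_length" in defaults:
            # a total-length budget must leave room for at least one generated token
            max_tokens = max(defaults["max_length"] - prompt_length, 1)

        if max_length:  # request-level override wins over both defaults
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        return max_tokens

The 1.0 fallbacks read the same way as the pop fix: (value or 1.0) treats both None and 0 as "unset", so a zero slipping through from the API cannot reach SamplingParams.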

src/llamafactory/data/template.py

@@ -68,8 +68,8 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
-        system: str,
-        tools: str,
+        system: Optional[str],
+        tools: Optional[str],
         cutoff_len: int,
         reserved_label_len: int,
     ) -> Sequence[Tuple[List[int], List[int]]]:
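
The template change is the other half of the fix: since the engines may now pass None straight through instead of substituting defaults, system and tools are typed Optional[str] rather than str. A minimal sketch of the downstream handling this implies (illustrative only, not the actual Template code):

    # Illustrative only: an Optional[str] system prompt treats None like "no prompt".
    from typing import Optional

    def render_system_prompt(system: Optional[str]) -> str:
        return system if system else ""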