fix hf chat engine

Former-commit-id: 97346c1d3d87f0bd5ddcd70ff485f6a8273244aa
This commit is contained in:
hiyouga 2024-05-29 01:20:07 +08:00
parent 13e7b64641
commit 05277ee864

View File

@ -8,6 +8,7 @@ import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response from .base_engine import BaseEngine, Response
@ -23,6 +24,9 @@ if TYPE_CHECKING:
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
logger = get_logger(__name__)
class HuggingfaceEngine(BaseEngine): class HuggingfaceEngine(BaseEngine):
def __init__( def __init__(
self, self,
@ -92,7 +96,7 @@ class HuggingfaceEngine(BaseEngine):
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
if stop is not None: if stop is not None:
raise ValueError("Stop parameter is not supported in Huggingface engine yet.") logger.warning("Stop parameter is not supported in Huggingface engine yet.")
generating_args = generating_args.copy() generating_args = generating_args.copy()
generating_args.update( generating_args.update(