commit 7543dc4a9d
parent 841fa0030f
Author: hiyouga
Date:   2024-01-20 23:22:09 +08:00

    Former-commit-id: ba97550671811a27177306dd231bb427130b26fb

5 changed files with 316 additions and 282 deletions


@@ -1,11 +1,11 @@
 from dataclasses import dataclass
 from threading import Thread
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
+from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple

 import torch
 from transformers import GenerationConfig, TextIteratorStreamer

-from ..data import Role, get_template_and_fix_tokenizer
+from ..data import get_template_and_fix_tokenizer
 from ..extras.misc import get_logits_processor
 from ..hparams import get_infer_args
 from ..model import dispatch_model, load_model_and_tokenizer
@@ -32,20 +32,11 @@ class ChatModel:
     def _process_args(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
-        messages = []
-        if history is not None:
-            for old_prompt, old_response in history:
-                messages.append({"role": Role.USER, "content": old_prompt})
-                messages.append({"role": Role.ASSISTANT, "content": old_response})
-
-        messages.append({"role": Role.USER, "content": query})
-        messages.append({"role": Role.ASSISTANT, "content": ""})
         prompt, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
         )
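With this hunk, _process_args no longer rebuilds the message list from a (query, history) pair; callers pass OpenAI-style role/content dictionaries directly. A minimal sketch of the conversion callers now perform themselves, mirroring the removed loop (history_to_messages is a hypothetical helper, and the plain "user"/"assistant" strings are assumed to match the Role.USER/Role.ASSISTANT values):

from typing import Dict, List, Optional, Sequence, Tuple

def history_to_messages(
    query: str, history: Optional[Sequence[Tuple[str, str]]] = None
) -> List[Dict[str, str]]:
    # Rebuild, on the caller side, the list the removed code used to create internally.
    messages = []
    for old_prompt, old_response in history or []:
        messages.append({"role": "user", "content": old_prompt})         # Role.USER assumed to equal "user"
        messages.append({"role": "assistant", "content": old_response})  # Role.ASSISTANT assumed to equal "assistant"
    messages.append({"role": "user", "content": query})
    return messages

Whether the trailing empty assistant turn that the old code appended is still added before encode_oneturn is not shown in this hunk.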
@@ -97,18 +88,12 @@ class ChatModel:
     @torch.inference_mode()
     def chat(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> List[Response]:
-        r"""
-        Args: query, history, system, **input_kwargs
-
-        Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
-        """
-        gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
+        gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
         response = self.tokenizer.batch_decode(
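For reference, a hedged usage sketch of the new chat() signature. The llmtuner import path, the dict-style constructor arguments, the model and template names, and the response_text field are assumptions, not shown in this diff:

from llmtuner import ChatModel  # top-level import path assumed

# Constructor arguments are placeholders; ChatModel is assumed to accept a plain dict of inference args.
chat_model = ChatModel({
    "model_name_or_path": "meta-llama/Llama-2-7b-chat-hf",
    "template": "llama2",
})

messages = [{"role": "user", "content": "What is the capital of France?"}]
responses = chat_model.chat(messages, system="You are a helpful assistant.")
print(responses[0].response_text)  # field name assumed from the Response dataclass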
@@ -132,13 +117,12 @@ class ChatModel:
     @torch.inference_mode()
     def stream_chat(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
+        gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer