mirror of https://github.com/hiyouga/LLaMA-Factory.git
@@ -1,11 +1,11 @@
 from dataclasses import dataclass
 from threading import Thread
-from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
+from typing import Any, Dict, Generator, List, Literal, Optional, Sequence, Tuple
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
 
-from ..data import Role, get_template_and_fix_tokenizer
+from ..data import get_template_and_fix_tokenizer
 from ..extras.misc import get_logits_processor
 from ..hparams import get_infer_args
 from ..model import dispatch_model, load_model_and_tokenizer
@@ -32,20 +32,11 @@ class ChatModel:
 
     def _process_args(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Tuple[Dict[str, Any], int]:
-        messages = []
-        if history is not None:
-            for old_prompt, old_response in history:
-                messages.append({"role": Role.USER, "content": old_prompt})
-                messages.append({"role": Role.ASSISTANT, "content": old_response})
-
-        messages.append({"role": Role.USER, "content": query})
-        messages.append({"role": Role.ASSISTANT, "content": ""})
         prompt, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=messages, system=system, tools=tools
         )
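The block removed from _process_args above is exactly the mapping a caller now has to perform itself: an old (query, history) pair becomes a list of role/content dicts. Below is a minimal caller-side sketch of that conversion. The helper name history_to_messages is illustrative and not part of LLaMA-Factory, and it assumes plain "user"/"assistant" strings are accepted where the old code used the Role enum (whose members serialize to those strings). Whether the trailing empty assistant turn that the old code appended is still required is not visible in this hunk, so it is left out here.

from typing import Dict, List, Optional, Tuple


def history_to_messages(query: str, history: Optional[List[Tuple[str, str]]] = None) -> List[Dict[str, str]]:
    # Illustrative helper, not part of the library: rebuild the message list the
    # way the removed block in _process_args used to, but on the caller side.
    messages: List[Dict[str, str]] = []
    for old_prompt, old_response in history or []:
        messages.append({"role": "user", "content": old_prompt})
        messages.append({"role": "assistant", "content": old_response})
    messages.append({"role": "user", "content": query})  # the current query goes last
    return messages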
@@ -97,18 +88,12 @@ class ChatModel:
     @torch.inference_mode()
     def chat(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> List[Response]:
-        r"""
-        Args: query, history, system, **input_kwargs
-
-        Returns: [(response_text, prompt_length, response_length)] * n (default n=1)
-        """
-        gen_kwargs, prompt_length = self._process_args(query, history, system, tools, **input_kwargs)
+        gen_kwargs, prompt_length = self._process_args(messages, system, tools, **input_kwargs)
         generate_output = self.model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
         response = self.tokenizer.batch_decode(
@@ -132,13 +117,12 @@ class ChatModel:
     @torch.inference_mode()
     def stream_chat(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
+        messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self._process_args(query, history, system, tools, **input_kwargs)
+        gen_kwargs, _ = self._process_args(messages, system, tools, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
 
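Taken together, the hunks switch chat, stream_chat, and _process_args from a (query, history) interface to a single messages argument. A minimal usage sketch under the new signature follows; the import path and constructor arguments are assumptions (they are not shown in this diff and depend on the installed version), while Response.response_text follows the removed docstring's description of the return value.

from llmtuner.chat import ChatModel  # assumed import path; later releases expose llamafactory.chat

# Assumed constructor arguments; ChatModel.__init__ is not part of this diff.
chat_model = ChatModel({"model_name_or_path": "path/to/model", "template": "default"})

messages = [
    {"role": "user", "content": "How are you?"},
    {"role": "assistant", "content": "I am fine. How can I help you today?"},
    {"role": "user", "content": "Summarize our conversation so far."},
]

# Blocking call: returns a list of Response objects (response_text, prompt_length, response_length).
responses = chat_model.chat(messages, system="You are a helpful assistant.")
print(responses[0].response_text)

# Streaming call: yields decoded text chunks as the model generates them.
for new_text in chat_model.stream_chat(messages, system="You are a helpful assistant."):
    print(new_text, end="", flush=True)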