fix chat engine, update webui

Former-commit-id: 8b32dddd7d883bae07735796a517927c79d1c33b
This commit is contained in:
hiyouga
2024-03-08 03:01:53 +08:00
parent 8042c66a76
commit 48d4364586
9 changed files with 250 additions and 83 deletions

View File

@@ -1,4 +1,5 @@
import asyncio
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
from ..hparams import get_infer_args
@@ -10,21 +11,24 @@ if TYPE_CHECKING:
from .base_engine import BaseEngine, Response
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
class ChatModel:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
if model_args.infer_backend == "hf":
if model_args.infer_backend == "huggingface":
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm":
self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
def _get_event_loop():
try:
return asyncio.get_running_loop()
except RuntimeError:
return asyncio.new_event_loop()
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
def chat(
self,
@@ -33,8 +37,8 @@ class ChatModel:
tools: Optional[str] = None,
**input_kwargs,
) -> List["Response"]:
loop = self._get_event_loop()
return loop.run_until_complete(self.achat(messages, system, tools, **input_kwargs))
task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
return task.result()
async def achat(
self,
@@ -52,11 +56,11 @@ class ChatModel:
tools: Optional[str] = None,
**input_kwargs,
) -> Generator[str, None, None]:
loop = self._get_event_loop()
generator = self.astream_chat(messages, system, tools, **input_kwargs)
while True:
try:
yield loop.run_until_complete(generator.__anext__())
task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
yield task.result()
except StopAsyncIteration:
break
@@ -75,8 +79,8 @@ class ChatModel:
batch_input: List[str],
**input_kwargs,
) -> List[float]:
loop = self._get_event_loop()
return loop.run_until_complete(self.aget_scores(batch_input, **input_kwargs))
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result()
async def aget_scores(
self,

View File

@@ -147,7 +147,7 @@ class HuggingfaceEngine(BaseEngine):
)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs)
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
thread.start()
def stream():