Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-12-14 10:56:56 +08:00
fix chat engine, update webui
Former-commit-id: 8b32dddd7d883bae07735796a517927c79d1c33b
@@ -1,4 +1,5 @@
 import asyncio
+from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

 from ..hparams import get_infer_args
@@ -10,21 +11,24 @@ if TYPE_CHECKING:
 from .base_engine import BaseEngine, Response


+def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+    asyncio.set_event_loop(loop)
+    loop.run_forever()
+
+
 class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
-        if model_args.infer_backend == "hf":
+        if model_args.infer_backend == "huggingface":
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))

-    def _get_event_loop():
-        try:
-            return asyncio.get_running_loop()
-        except RuntimeError:
-            return asyncio.new_event_loop()
+        self._loop = asyncio.new_event_loop()
+        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+        self._thread.start()

     def chat(
         self,
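The change replaces per-call `run_until_complete` on an ad-hoc loop with a single event loop that lives on a daemon thread for the lifetime of the `ChatModel`. The old `_get_event_loop` approach breaks whenever the caller's thread already has a running loop (for example inside a Gradio or FastAPI handler), because `run_until_complete` then raises "This event loop is already running". A minimal, self-contained sketch of the new pattern (the `Worker`/`agreet` names are illustrative, not from the repository):

```python
import asyncio
from threading import Thread


def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    # Bind the loop to this thread, then block forever processing
    # coroutines submitted from other threads.
    asyncio.set_event_loop(loop)
    loop.run_forever()


class Worker:  # hypothetical stand-in for ChatModel
    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        # daemon=True: the interpreter may exit without joining this thread.
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    async def agreet(self, name: str) -> str:
        await asyncio.sleep(0.1)  # stand-in for real async work
        return "hello, " + name

    def greet(self, name: str) -> str:
        # Submit the coroutine to the background loop; block on the
        # concurrent.futures.Future that run_coroutine_threadsafe returns.
        future = asyncio.run_coroutine_threadsafe(self.agreet(name), self._loop)
        return future.result()


print(Worker().greet("world"))  # -> hello, world
```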
@@ -33,8 +37,8 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        loop = self._get_event_loop()
-        return loop.run_until_complete(self.achat(messages, system, tools, **input_kwargs))
+        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, **input_kwargs), self._loop)
+        return task.result()

     async def achat(
         self,
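After this change, `chat` is a thin synchronous wrapper over `achat`: sync callers pay one thread hop, while async callers can still await the engine directly. A hedged usage sketch (the constructor dict and message format below are assumptions for illustration):

```python
# Synchronous caller: blocks until the background loop completes the coroutine.
chat_model = ChatModel({"model_name_or_path": "path/to/model", "infer_backend": "huggingface"})
responses = chat_model.chat(messages=[{"role": "user", "content": "hi"}])


# Asynchronous caller: no thread hop, awaits the engine on the current loop.
async def handler():
    return await chat_model.achat(messages=[{"role": "user", "content": "hi"}])
```

The `get_scores`/`aget_scores` pair further down receives the identical treatment.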
@@ -52,11 +56,11 @@ class ChatModel:
         tools: Optional[str] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        loop = self._get_event_loop()
         generator = self.astream_chat(messages, system, tools, **input_kwargs)
         while True:
             try:
-                yield loop.run_until_complete(generator.__anext__())
+                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
+                yield task.result()
             except StopAsyncIteration:
                 break

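Streaming cannot hand the whole async generator to the background loop at once, so each `__anext__()` is submitted as its own coroutine and the synchronous generator re-yields the result; `StopAsyncIteration` propagates through `task.result()` and terminates the loop. A standalone sketch of the idiom (`aticker` is made up for illustration):

```python
import asyncio
from threading import Thread
from typing import AsyncGenerator, Generator


async def aticker(n: int) -> AsyncGenerator[int, None]:
    for i in range(n):
        await asyncio.sleep(0.05)
        yield i


def ticker(n: int, loop: asyncio.AbstractEventLoop) -> Generator[int, None, None]:
    agen = aticker(n)
    while True:
        try:
            # One thread-safe round trip to the background loop per item.
            yield asyncio.run_coroutine_threadsafe(agen.__anext__(), loop).result()
        except StopAsyncIteration:
            break  # generator exhausted; the exception crossed threads intact


def _run(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


loop = asyncio.new_event_loop()
Thread(target=_run, args=(loop,), daemon=True).start()
print(list(ticker(3, loop)))  # -> [0, 1, 2]
```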
@@ -75,8 +79,8 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
-        loop = self._get_event_loop()
-        return loop.run_until_complete(self.aget_scores(batch_input, **input_kwargs))
+        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
+        return task.result()

     async def aget_scores(
         self,
@@ -147,7 +147,7 @@ class HuggingfaceEngine(BaseEngine):
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
-        thread = Thread(target=model.generate, kwargs=gen_kwargs)
+        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()

         def stream():
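The `daemon=True` on the generation thread is a shutdown fix rather than a behavior change: `model.generate` can block for a long time, and a non-daemon thread would keep the interpreter alive after the main thread exits (for example when the webui is closed mid-stream). A minimal illustration of the difference (`busy_generate` is a stand-in, not the real call):

```python
import time
from threading import Thread


def busy_generate() -> None:
    time.sleep(3600)  # stand-in for a long-running model.generate()


# With daemon=True the process exits as soon as the main thread returns,
# abandoning the worker. With the default daemon=False, Python would wait
# the full hour for busy_generate() to finish before exiting.
Thread(target=busy_generate, daemon=True).start()
print("main thread done; interpreter exits without waiting")
```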