From 15a5eb664737ddd9cc780e5c0f5a0b780b18b2fb Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 18 Jun 2024 22:08:56 +0800
Subject: [PATCH] fix #4335

Former-commit-id: c96264bc477d65276557e9059cac7c550c4835a8
---
 src/llamafactory/chat/base_engine.py |  5 -----
 src/llamafactory/chat/chat_model.py  |  2 --
 src/llamafactory/chat/hf_engine.py   | 10 ++++------
 src/llamafactory/chat/vllm_engine.py |  3 ---
 4 files changed, 4 insertions(+), 16 deletions(-)

diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py
index 92a51ebe..ccdf4c92 100644
--- a/src/llamafactory/chat/base_engine.py
+++ b/src/llamafactory/chat/base_engine.py
@@ -50,11 +50,6 @@ class BaseEngine(ABC):
         generating_args: "GeneratingArguments",
     ) -> None: ...
 
-    @abstractmethod
-    async def start(
-        self,
-    ) -> None: ...
-
     @abstractmethod
     async def chat(
         self,
diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 2a72f422..5c83fa67 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -49,8 +49,6 @@ class ChatModel:
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
-        task = asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
-        task.result()
 
     def chat(
         self,
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index a7ff7015..30200456 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab
         self.generating_args = generating_args.to_dict()
+        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
 
     @staticmethod
     def _process_args(
@@ -259,9 +260,6 @@ class HuggingfaceEngine(BaseEngine):
 
         return scores
 
-    async def start(self) -> None:
-        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -286,7 +284,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)
 
@@ -314,7 +312,7 @@ class HuggingfaceEngine(BaseEngine):
             image,
             input_kwargs,
         )
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 stream = self._stream_chat(*input_args)
                 while True:
@@ -333,6 +331,6 @@ class HuggingfaceEngine(BaseEngine):
         loop = asyncio.get_running_loop()
 
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
-        async with self._semaphore:
+        async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._get_scores, *input_args)
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index d488a039..2626d612 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -183,9 +183,6 @@ class VllmEngine(BaseEngine):
         )
         return result_generator
 
-    async def start(self) -> None:
-        pass
-
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
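
Notes (illustrative sketch, not part of the patch): after this change the
semaphore is created eagerly in HuggingfaceEngine.__init__ instead of inside
an async start() hook, so ChatModel no longer needs to schedule start() on
its background loop. The self-contained sketch below reproduces that pattern
under stated assumptions: DemoEngine and its chat() body are hypothetical
stand-ins, and the eager construction relies on asyncio synchronization
primitives binding to a loop lazily on first use (construction-time loop
binding was removed in Python 3.10); only MAX_CONCURRENT and the
background-loop setup mirror the diff itself.

    import asyncio
    import os
    from threading import Thread

    class DemoEngine:  # hypothetical stand-in for HuggingfaceEngine
        def __init__(self) -> None:
            # Created outside any running loop, as in the patched __init__;
            # the semaphore attaches to whichever loop first awaits it.
            self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

        async def chat(self, message: str) -> str:
            async with self.semaphore:  # cap concurrent generations
                await asyncio.sleep(0.1)  # stand-in for model inference
                return "echo: " + message

    def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
        asyncio.set_event_loop(loop)
        loop.run_forever()

    engine = DemoEngine()
    loop = asyncio.new_event_loop()
    Thread(target=_start_background_loop, args=(loop,), daemon=True).start()
    # Same scheduling pattern ChatModel uses: submit the coroutine to the
    # background loop from the calling thread and block on its result.
    future = asyncio.run_coroutine_threadsafe(engine.chat("hello"), loop)
    print(future.result())

With the semaphore owned by __init__, an engine is ready to serve requests as
soon as it is constructed, which is why the abstract start() method can be
dropped from BaseEngine and VllmEngine's empty override deleted.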