From 1173441661840c7753b35fa72979c5e7c82824f1 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Sat, 9 Mar 2024 21:35:24 +0800
Subject: [PATCH] fix #2766

Former-commit-id: 412c52e325660e8b871ffd59f5564f84f46a143f
---
 src/llmtuner/chat/base_engine.py | 5 +++++
 src/llmtuner/chat/chat_model.py  | 1 +
 src/llmtuner/chat/hf_engine.py   | 4 +++-
 src/llmtuner/chat/vllm_engine.py | 3 +++
 src/llmtuner/train/utils.py      | 2 ++
 5 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/src/llmtuner/chat/base_engine.py b/src/llmtuner/chat/base_engine.py
index ea46bfba..c5db41da 100644
--- a/src/llmtuner/chat/base_engine.py
+++ b/src/llmtuner/chat/base_engine.py
@@ -38,6 +38,11 @@ class BaseEngine(ABC):
         generating_args: "GeneratingArguments",
     ) -> None: ...
 
+    @abstractmethod
+    async def start(
+        self,
+    ) -> None: ...
+
     @abstractmethod
     async def chat(
         self,
diff --git a/src/llmtuner/chat/chat_model.py b/src/llmtuner/chat/chat_model.py
index 9b509180..c49d4d78 100644
--- a/src/llmtuner/chat/chat_model.py
+++ b/src/llmtuner/chat/chat_model.py
@@ -29,6 +29,7 @@ class ChatModel:
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
+        asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
 
     def chat(
         self,
diff --git a/src/llmtuner/chat/hf_engine.py b/src/llmtuner/chat/hf_engine.py
index 7def2b75..c634ba16 100644
--- a/src/llmtuner/chat/hf_engine.py
+++ b/src/llmtuner/chat/hf_engine.py
@@ -36,7 +36,6 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
         self.generating_args = generating_args.to_dict()
-        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
 
     @staticmethod
     def _process_args(
@@ -191,6 +190,9 @@
 
         return scores
 
+    async def start(self) -> None:
+        self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
+
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py
index 258accb6..b147d19b 100644
--- a/src/llmtuner/chat/vllm_engine.py
+++ b/src/llmtuner/chat/vllm_engine.py
@@ -97,6 +97,9 @@ class VllmEngine(BaseEngine):
         )
         return result_generator
 
+    async def start(self) -> None:
+        pass
+
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py
index 77ed0f04..39fb77f4 100644
--- a/src/llmtuner/train/utils.py
+++ b/src/llmtuner/train/utils.py
@@ -1,6 +1,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 
 import torch
+from transformers.utils.versions import require_version
 
 from ..extras.logging import get_logger
 from ..extras.packages import is_galore_available
@@ -131,6 +132,7 @@ def create_custom_optimzer(
     if not finetuning_args.use_galore:
         return None
 
+    require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
     galore_params = []
     galore_targets = finetuning_args.galore_target.split(",")