Former-commit-id: 412c52e325660e8b871ffd59f5564f84f46a143f
This commit is contained in:
hiyouga 2024-03-09 21:35:24 +08:00
parent 8f6eb1383d
commit 1173441661
5 changed files with 14 additions and 1 deletion

View File

@ -38,6 +38,11 @@ class BaseEngine(ABC):
generating_args: "GeneratingArguments",
) -> None: ...
@abstractmethod
async def start(
    self,
) -> None:
    """Asynchronously initialize the engine after construction.

    Declared abstract so every concrete engine provides a startup hook;
    NOTE(review): callers appear to schedule this on a background event
    loop, so loop-bound resources should be created here — confirm
    against the concrete implementations.
    """
    ...
@abstractmethod
async def chat(
self,

View File

@ -29,6 +29,7 @@ class ChatModel:
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
def chat(
self,

View File

@ -36,7 +36,6 @@ class HuggingfaceEngine(BaseEngine):
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.generating_args = generating_args.to_dict()
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
@staticmethod
def _process_args(
@ -191,6 +190,9 @@ class HuggingfaceEngine(BaseEngine):
return scores
async def start(self) -> None:
    """Create the request-concurrency limiter for this engine.

    Reads the ``MAX_CONCURRENT`` environment variable (default ``1``)
    and stores an ``asyncio.Semaphore`` of that size on the instance.
    NOTE(review): building the semaphore inside ``start`` presumably
    binds it to the event loop that runs the coroutine — confirm.
    """
    max_concurrent = os.environ.get("MAX_CONCURRENT", 1)
    self._semaphore = asyncio.Semaphore(int(max_concurrent))
async def chat(
self,
messages: Sequence[Dict[str, str]],

View File

@ -97,6 +97,9 @@ class VllmEngine(BaseEngine):
)
return result_generator
async def start(self) -> None:
    """Startup hook required by the engine interface; nothing to set up here."""
    return None
async def chat(
self,
messages: Sequence[Dict[str, str]],

View File

@ -1,6 +1,7 @@
from typing import TYPE_CHECKING, Optional, Union
import torch
from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
@ -131,6 +132,7 @@ def create_custom_optimzer(
if not finetuning_args.use_galore:
return None
require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
galore_params = []
galore_targets = finetuning_args.galore_target.split(",")