mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
parent
8f6eb1383d
commit
1173441661
@ -38,6 +38,11 @@ class BaseEngine(ABC):
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def start(
|
||||
self,
|
||||
) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
self,
|
||||
|
@ -29,6 +29,7 @@ class ChatModel:
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||
self._thread.start()
|
||||
asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
|
@ -36,7 +36,6 @@ class HuggingfaceEngine(BaseEngine):
|
||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||
self.generating_args = generating_args.to_dict()
|
||||
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||
|
||||
@staticmethod
|
||||
def _process_args(
|
||||
@ -191,6 +190,9 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
return scores
|
||||
|
||||
async def start(self) -> None:
|
||||
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
|
@ -97,6 +97,9 @@ class VllmEngine(BaseEngine):
|
||||
)
|
||||
return result_generator
|
||||
|
||||
async def start(self) -> None:
|
||||
pass
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.packages import is_galore_available
|
||||
@ -131,6 +132,7 @@ def create_custom_optimzer(
|
||||
if not finetuning_args.use_galore:
|
||||
return None
|
||||
|
||||
require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
|
||||
galore_params = []
|
||||
galore_targets = finetuning_args.galore_target.split(",")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user