mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 04:32:50 +08:00
parent
8f6eb1383d
commit
1173441661
@ -38,6 +38,11 @@ class BaseEngine(ABC):
|
|||||||
generating_args: "GeneratingArguments",
|
generating_args: "GeneratingArguments",
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def start(
|
||||||
|
self,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
|
@ -29,6 +29,7 @@ class ChatModel:
|
|||||||
self._loop = asyncio.new_event_loop()
|
self._loop = asyncio.new_event_loop()
|
||||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
|
asyncio.run_coroutine_threadsafe(self.engine.start(), self._loop)
|
||||||
|
|
||||||
def chat(
|
def chat(
|
||||||
self,
|
self,
|
||||||
|
@ -36,7 +36,6 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
self.tokenizer.padding_side = "left" if self.can_generate else "right"
|
||||||
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
|
||||||
self.generating_args = generating_args.to_dict()
|
self.generating_args = generating_args.to_dict()
|
||||||
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _process_args(
|
def _process_args(
|
||||||
@ -191,6 +190,9 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
|
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
self._semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", 1)))
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
|
@ -97,6 +97,9 @@ class VllmEngine(BaseEngine):
|
|||||||
)
|
)
|
||||||
return result_generator
|
return result_generator
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
messages: Sequence[Dict[str, str]],
|
messages: Sequence[Dict[str, str]],
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.packages import is_galore_available
|
from ..extras.packages import is_galore_available
|
||||||
@ -131,6 +132,7 @@ def create_custom_optimzer(
|
|||||||
if not finetuning_args.use_galore:
|
if not finetuning_args.use_galore:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
require_version("galore_torch", "To fix: pip install git+https://github.com/hiyouga/GaLore.git")
|
||||||
galore_params = []
|
galore_params = []
|
||||||
galore_targets = finetuning_args.galore_target.split(",")
|
galore_targets = finetuning_args.galore_target.split(",")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user