Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-01-13 09:30:34 +08:00
[v1] use async streamer (#9741)
@@ -230,7 +230,7 @@ def load_model(
     )
     from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels

-    model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
+    model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)

     trainable_params, all_param = count_parameters(model)
     if is_trainable:
@@ -15,11 +15,11 @@
 import asyncio
 import os
 from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator
 from threading import Thread

 import torch
-from transformers import TextIteratorStreamer
+from transformers import AsyncTextIteratorStreamer

 from ..accelerator.interface import DistributedInterface
 from ..config import ModelArguments, SampleArguments, SampleBackend
@@ -88,9 +88,10 @@ class HuggingFaceEngine(BaseEngine):
         self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

     @torch.inference_mode()
-    def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
+    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
+        async with self.semaphore:
             model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
-        streamer = TextIteratorStreamer(
+            streamer = AsyncTextIteratorStreamer(
                 tokenizer=get_tokenizer(self.renderer.processor),
                 skip_prompt=True,
                 skip_special_tokens=True,  # TODO: configurable
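Note: the hunk above gates the entire generation stream behind an asyncio.Semaphore sized by the MAX_CONCURRENT environment variable. A minimal, self-contained sketch of that pattern (the generate coroutine and its tokens are illustrative, not repository code):

    import asyncio
    import os

    # Same cap as the engine: MAX_CONCURRENT defaults to 1, fully serializing requests.
    semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    async def generate(request_id: int):
        async with semaphore:  # held for the whole stream, released when the generator closes
            for token in ("hello", "world"):
                await asyncio.sleep(0.1)  # stand-in for awaiting the streamer
                yield f"[{request_id}] {token}"

    async def main():
        async def consume(i: int):
            async for chunk in generate(i):
                print(chunk)

        # With MAX_CONCURRENT=1, the four consumers stream one at a time.
        await asyncio.gather(*(consume(i) for i in range(4)))

    asyncio.run(main())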
@@ -105,22 +106,8 @@ class HuggingFaceEngine(BaseEngine):
             thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
             thread.start()

-        def stream():
-            try:
-                return streamer.__next__()
-            except StopIteration:
-                raise StopAsyncIteration()
-
-        return stream
-
-    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
-        async with self.semaphore:
-            response = self.get_response(messages, tools)
-            while True:
-                try:
-                    yield await asyncio.to_thread(response)
-                except StopAsyncIteration:
-                    break
+            async for token in streamer:
+                yield token

     async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
         """Batch infer samples.
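The two hunks above replace the sync TextIteratorStreamer plus its thread-hopping shim (a stream() closure polled once per token via asyncio.to_thread) with transformers' AsyncTextIteratorStreamer, which is natively async-iterable. A self-contained sketch of the resulting pattern, with a placeholder model and prompt:

    import asyncio
    from threading import Thread

    from transformers import AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer

    async def main() -> None:
        tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
        model = AutoModelForCausalLM.from_pretrained("gpt2")
        inputs = tokenizer("Hello, my name is", return_tensors="pt")

        # The streamer captures the running event loop at construction, which is
        # why the engine now builds it inside `async def generate`.
        streamer = AsyncTextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        # model.generate blocks, so it runs on a daemon thread while the event
        # loop consumes decoded text chunks as they are produced.
        thread = Thread(
            target=model.generate,
            kwargs={**inputs, "streamer": streamer, "max_new_tokens": 16},
            daemon=True,
        )
        thread.start()

        async for text in streamer:
            print(text, end="", flush=True)

    asyncio.run(main())

This also removes the per-token asyncio.to_thread round trip: the generation thread hands decoded text to the event loop directly.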
@@ -28,8 +28,9 @@ Train Phase:
 """

 from ..config.training_args import TrainingArguments
-from ..utils.types import HFModel, Processor, TorchDataset
-from .trainer_utils.data_collator import DataCollator
+from ..utils.types import HFModel, TorchDataset
+from .utils.data_collator import DataCollator
+from .utils.rendering import Renderer


 class BaseTrainer:
@@ -37,21 +38,21 @@ class BaseTrainer:
         self,
         args: TrainingArguments,
         model: HFModel,
-        processor: Processor,
+        renderer: Renderer,
         dataset: TorchDataset,
     ) -> None:
         self.args = args
         self.model = model
-        self.processor = processor
+        self.renderer = renderer
         self.dataset = dataset
         self.data_collator = DataCollator()
         self.optimizer = None
         self.lr_scheduler = None

-    def init_model_and_optimizer(self) -> None:
+    def _create_dataloader(self) -> None:
         pass

-    def create_dataloader(self) -> None:
+    def _init_model_and_optimizer(self) -> None:
         pass

     def fit(self) -> None:
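BaseTrainer now takes a Renderer instead of a raw Processor, and the lifecycle hooks become private (underscore-prefixed). A hypothetical sketch of how a concrete trainer could fill in the renamed hooks; the subclass name and the batch-size field are assumptions, not repository code:

    from torch.utils.data import DataLoader

    class SFTTrainer(BaseTrainer):  # hypothetical subclass
        def _create_dataloader(self) -> None:
            # Wire the dataset and collator the base class already stores.
            self.dataloader = DataLoader(
                self.dataset,
                batch_size=self.args.micro_batch_size,  # assumed TrainingArguments field
                collate_fn=self.data_collator,
            )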
@@ -87,7 +87,7 @@ class ModelEngine:
     def _init_model(self) -> HFModel:
         """Init model.

-        Let transformers handle the model init context.
+        Transformers can choose the proper model init context.
         https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
         """
         if self.args.model_class == ModelClass.LLM:
@@ -141,7 +141,7 @@ class ModelEngine:
         from ..plugins.model_plugins.kernels.interface import KernelPlugin

         model = KernelPlugin(self.args.kernel_config.name)(
-            model=model, include_kernels=self.args.kernel_config.get("include_kernels")
+            model, include_kernels=self.args.kernel_config.get("include_kernels")
         )

         return model
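The call site above treats KernelPlugin as a name-keyed callable: instantiating it with a name and then calling it dispatches to whatever function was registered under that name. A toy illustration of that registry shape (a guess for exposition only, not the actual BasePlugin implementation):

    from collections.abc import Callable

    class NamedPlugin:
        """Toy registry: calling the plugin runs the function registered under its name."""

        _registry: dict[str, Callable] = {}

        def __init__(self, name: str) -> None:
            self.name = name

        def register(self) -> Callable:
            def decorator(func: Callable) -> Callable:
                NamedPlugin._registry[self.name] = func
                return func
            return decorator

        def __call__(self, *args, **kwargs):
            return NamedPlugin._registry[self.name](*args, **kwargs)

    @NamedPlugin("auto").register()
    def apply_defaults(model, include_kernels=None):
        return model

    model = NamedPlugin("auto")({}, include_kernels="auto")  # dispatches to apply_defaults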
@@ -24,12 +24,13 @@ Init Phase:
 import importlib
 from pathlib import Path

-from ....utils.logging import get_logger
+from ....utils import logging
 from ....utils.plugin import BasePlugin
+from ....utils.types import HFModel
 from .registry import Registry


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def scan_all_kernels():
@@ -110,27 +111,30 @@ class KernelPlugin(BasePlugin):


 @KernelPlugin("auto").register()
-def apply_default_kernels(**kwargs):
+def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
     """Applies all default registered kernels to the model.

     Args:
-        **kwargs: Keyword arguments passed to the kernel application function.
-            Typically includes the model instance and the include_kernels configuration.
+        model (HFModel): The model instance to apply kernels to.
+        include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
+            If "auto" or True, applies all default kernels.
+            If None or False, no kernels are applied.
+            Defaults to None.

     Returns:
         HFModel: The model with applied kernels.
     """
-    if not kwargs.get("include_kernels"):  # None/False/empty string
-        return kwargs.get("model")
-    elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True:  # True/auto
+    if not include_kernels:
+        return model
+    elif include_kernels == "auto" or include_kernels is True:
         use_kernels = default_kernels.keys()
     else:
-        use_kernels = kwargs.get("include_kernels").split(",")  # "kernel_id1,kernel_id2,kernel_id3"
+        use_kernels = include_kernels.split(",")  # "kernel_id1,kernel_id2,kernel_id3"

     for kernel in use_kernels:
         if kernel not in default_kernels:
             raise ValueError(f"Kernel {kernel} not found")

-        apply_kernel(kernel, **kwargs)
+        apply_kernel(kernel, model=model)

-    return kwargs.get("model")
+    return model
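With the **kwargs bag replaced by an explicit signature, the three accepted shapes of include_kernels read directly off the code. A short usage sketch (the kernel IDs are illustrative placeholders):

    model = apply_default_kernels(model, include_kernels=None)    # falsy: model returned untouched
    model = apply_default_kernels(model, include_kernels="auto")  # apply every default kernel
    model = apply_default_kernels(model, include_kernels="kernel_id1,kernel_id2")  # only the listed IDs

An ID missing from default_kernels raises ValueError when the loop reaches it.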
@@ -20,8 +20,6 @@ Init Phase:

 """

-from typing import Optional
-
 from ....accelerator.helper import get_current_accelerator
 from .base import BaseKernel

@@ -73,14 +71,14 @@ class Registry:
         return kernel_cls

     @classmethod
-    def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
+    def get(cls, kernel_id: str) -> type[BaseKernel] | None:
         """Retrieves a registered kernel implementation by its ID.

         Args:
             kernel_id (str): The ID of the kernel to retrieve.

         Returns:
-            Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
+            type[BaseKernel] | None: The kernel class if found, else ``None``.
         """
         return cls._kernels.get(kernel_id)

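The return-annotation change is PEP 604 union syntax, equivalent at runtime to the old Optional spelling but requiring Python 3.10+ (or deferred annotation evaluation):

    from typing import Optional

    def get_old(kernel_id: str) -> Optional[type["BaseKernel"]]: ...  # pre-3.10 spelling
    def get_new(kernel_id: str) -> type["BaseKernel"] | None: ...     # PEP 604, Python 3.10+

Dropping `from typing import Optional` in the previous hunk follows once no Optional uses remain in the module.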