mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-01-13 09:30:34 +08:00
[v1] use async streamer (#9741)
@@ -230,7 +230,7 @@ def load_model(
     )

     from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels

-    model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
+    model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)

     trainable_params, all_param = count_parameters(model)
     if is_trainable:
@@ -15,11 +15,11 @@
 import asyncio
 import os
 from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator
 from threading import Thread

 import torch
-from transformers import TextIteratorStreamer
+from transformers import AsyncTextIteratorStreamer

 from ..accelerator.interface import DistributedInterface
 from ..config import ModelArguments, SampleArguments, SampleBackend
@@ -88,9 +88,10 @@ class HuggingFaceEngine(BaseEngine):
         self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

     @torch.inference_mode()
-    def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
+    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
+        async with self.semaphore:
             model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
-            streamer = TextIteratorStreamer(
+            streamer = AsyncTextIteratorStreamer(
                 tokenizer=get_tokenizer(self.renderer.processor),
                 skip_prompt=True,
                 skip_special_tokens=True,  # TODO: configurable
@@ -105,22 +106,8 @@ class HuggingFaceEngine(BaseEngine):
             thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
             thread.start()

-        def stream():
-            try:
-                return streamer.__next__()
-            except StopIteration:
-                raise StopAsyncIteration()
-
-        return stream
-
-    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
-        async with self.semaphore:
-            response = self.get_response(messages, tools)
-            while True:
-                try:
-                    yield await asyncio.to_thread(response)
-                except StopAsyncIteration:
-                    break
+            async for token in streamer:
+                yield token

     async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
         """Batch infer samples.
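The net effect of the hunks above is that HuggingFaceEngine.generate drives AsyncTextIteratorStreamer directly instead of polling a synchronous TextIteratorStreamer through asyncio.to_thread. A minimal consumer-side sketch, not part of the diff; "engine" and "messages" are assumed placeholders for a constructed HuggingFaceEngine and a chat history:

import asyncio

async def stream_reply(engine, messages):
    # generate() is now an async generator: tokens arrive as the background
    # generation thread feeds AsyncTextIteratorStreamer, with no to_thread hop.
    chunks = []
    async for token in engine.generate(messages):
        print(token, end="", flush=True)
        chunks.append(token)
    return "".join(chunks)

# asyncio.run(stream_reply(engine, messages))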
@@ -28,8 +28,9 @@ Train Phase:
 """

 from ..config.training_args import TrainingArguments
-from ..utils.types import HFModel, Processor, TorchDataset
-from .trainer_utils.data_collator import DataCollator
+from ..utils.types import HFModel, TorchDataset
+from .utils.data_collator import DataCollator
+from .utils.rendering import Renderer


 class BaseTrainer:
@@ -37,21 +38,21 @@ class BaseTrainer:
         self,
         args: TrainingArguments,
         model: HFModel,
-        processor: Processor,
+        renderer: Renderer,
         dataset: TorchDataset,
     ) -> None:
         self.args = args
         self.model = model
-        self.processor = processor
+        self.renderer = renderer
         self.dataset = dataset
         self.data_collator = DataCollator()
         self.optimizer = None
         self.lr_scheduler = None

-    def init_model_and_optimizer(self) -> None:
+    def _create_dataloader(self) -> None:
         pass

-    def create_dataloader(self) -> None:
+    def _init_model_and_optimizer(self) -> None:
         pass

     def fit(self) -> None:
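For reference, a hypothetical construction of the refactored trainer; only the renderer-for-processor swap comes from the diff, and the argument objects are assumed to be built elsewhere:

trainer = BaseTrainer(
    args=training_args,     # TrainingArguments
    model=model,            # HFModel
    renderer=renderer,      # Renderer now replaces the raw Processor argument
    dataset=train_dataset,  # TorchDataset
)
trainer.fit()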
@@ -87,7 +87,7 @@ class ModelEngine:
     def _init_model(self) -> HFModel:
         """Init model.

-        Let transformers handle the model init context.
+        Transformers can choose the proper model init context.
         https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
         """
         if self.args.model_class == ModelClass.LLM:
@@ -141,7 +141,7 @@ class ModelEngine:
         from ..plugins.model_plugins.kernels.interface import KernelPlugin

         model = KernelPlugin(self.args.kernel_config.name)(
-            model=model, include_kernels=self.args.kernel_config.get("include_kernels")
+            model, include_kernels=self.args.kernel_config.get("include_kernels")
         )

         return model
@@ -24,12 +24,13 @@ Init Phase:
 import importlib
 from pathlib import Path

-from ....utils.logging import get_logger
+from ....utils import logging
 from ....utils.plugin import BasePlugin
+from ....utils.types import HFModel
 from .registry import Registry


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def scan_all_kernels():
@@ -110,27 +111,30 @@ class KernelPlugin(BasePlugin):


 @KernelPlugin("auto").register()
-def apply_default_kernels(**kwargs):
+def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
     """Applies all default registered kernels to the model.

     Args:
-        **kwargs: Keyword arguments passed to the kernel application function.
-            Typically includes the model instance and the include_kernels configuration.
+        model (HFModel): The model instance to apply kernels to.
+        include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
+            If "auto" or True, applies all default kernels.
+            If None or False, no kernels are applied.
+            Defaults to None.

     Returns:
         HFModel: The model with applied kernels.
     """
-    if not kwargs.get("include_kernels"):  # None/False/empty string
-        return kwargs.get("model")
-    elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True:  # True/auto
+    if not include_kernels:
+        return model
+    elif include_kernels == "auto" or include_kernels is True:
         use_kernels = default_kernels.keys()
     else:
-        use_kernels = kwargs.get("include_kernels").split(",")  # "kernel_id1,kernel_id2,kernel_id3"
+        use_kernels = include_kernels.split(",")  # "kernel_id1,kernel_id2,kernel_id3"

     for kernel in use_kernels:
         if kernel not in default_kernels:
             raise ValueError(f"Kernel {kernel} not found")

-        apply_kernel(kernel, **kwargs)
+        apply_kernel(kernel, model=model)

-    return kwargs.get("model")
+    return model
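A short behavior sketch of the new explicit signature, mirroring the docstring above; the kernel IDs are hypothetical and "model" stands for any HFModel:

model = apply_default_kernels(model)                          # include_kernels=None: model returned unchanged
model = apply_default_kernels(model, include_kernels="auto")  # apply every registered default kernel
model = apply_default_kernels(model, include_kernels="kernel_a,kernel_b")  # selected IDs only; unknown IDs raise ValueError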
@@ -20,8 +20,6 @@ Init Phase:

 """

-from typing import Optional
-
 from ....accelerator.helper import get_current_accelerator
 from .base import BaseKernel
@@ -73,14 +71,14 @@ class Registry:
         return kernel_cls

     @classmethod
-    def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
+    def get(cls, kernel_id: str) -> type[BaseKernel] | None:
         """Retrieves a registered kernel implementation by its ID.

         Args:
             kernel_id (str): The ID of the kernel to retrieve.

         Returns:
-            Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
+            type[BaseKernel] | None: The kernel class if found, else ``None``.
         """
         return cls._kernels.get(kernel_id)