[v1] use async streamer (#9741)

Author: Yaowei Zheng
Committed by: hiyouga
Date: 2026-01-09 16:07:40 +08:00
parent 766d5ae6ad
commit 8abb8fb533

6 changed files with 47 additions and 57 deletions

View File

@@ -230,7 +230,7 @@ def load_model(
     )
     from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels

-    model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
+    model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)
     trainable_params, all_param = count_parameters(model)
     if is_trainable:

View File

@@ -15,11 +15,11 @@
import asyncio import asyncio
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator from collections.abc import AsyncGenerator
from threading import Thread from threading import Thread
import torch import torch
from transformers import TextIteratorStreamer from transformers import AsyncTextIteratorStreamer
from ..accelerator.interface import DistributedInterface from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend from ..config import ModelArguments, SampleArguments, SampleBackend
@@ -88,39 +88,26 @@ class HuggingFaceEngine(BaseEngine):
         self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

     @torch.inference_mode()
-    def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
-        model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
-        streamer = TextIteratorStreamer(
-            tokenizer=get_tokenizer(self.renderer.processor),
-            skip_prompt=True,
-            skip_special_tokens=True,  # TODO: configurable
-        )
-        device = DistributedInterface().current_device
-        kwargs = {
-            "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
-            "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
-            "max_new_tokens": self.args.max_new_tokens,
-            "streamer": streamer,
-        }
-        thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
-        thread.start()
-
-        def stream():
-            try:
-                return streamer.__next__()
-            except StopIteration:
-                raise StopAsyncIteration()
-
-        return stream
-
     async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
         async with self.semaphore:
-            response = self.get_response(messages, tools)
-            while True:
-                try:
-                    yield await asyncio.to_thread(response)
-                except StopAsyncIteration:
-                    break
+            model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
+            streamer = AsyncTextIteratorStreamer(
+                tokenizer=get_tokenizer(self.renderer.processor),
+                skip_prompt=True,
+                skip_special_tokens=True,  # TODO: configurable
+            )
+            device = DistributedInterface().current_device
+            kwargs = {
+                "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
+                "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
+                "max_new_tokens": self.args.max_new_tokens,
+                "streamer": streamer,
+            }
+            thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
+            thread.start()
+
+            async for token in streamer:
+                yield token

     async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
         """Batch infer samples.

View File

@@ -28,8 +28,9 @@ Train Phase:
 """

 from ..config.training_args import TrainingArguments
-from ..utils.types import HFModel, Processor, TorchDataset
-from .trainer_utils.data_collator import DataCollator
+from ..utils.types import HFModel, TorchDataset
+from .utils.data_collator import DataCollator
+from .utils.rendering import Renderer


 class BaseTrainer:
@@ -37,21 +38,21 @@ class BaseTrainer:
         self,
         args: TrainingArguments,
         model: HFModel,
-        processor: Processor,
+        renderer: Renderer,
         dataset: TorchDataset,
     ) -> None:
         self.args = args
         self.model = model
-        self.processor = processor
+        self.renderer = renderer
         self.dataset = dataset
         self.data_collator = DataCollator()
         self.optimizer = None
         self.lr_scheduler = None

-    def init_model_and_optimizer(self) -> None:
+    def _create_dataloader(self) -> None:
         pass

-    def create_dataloader(self) -> None:
+    def _init_model_and_optimizer(self) -> None:
         pass

     def fit(self) -> None:
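
For orientation, a hypothetical call site after the rename; the Renderer constructor and the variable names below are assumptions for illustration, not the repo's confirmed API:

renderer = Renderer(processor)  # assumed constructor: the Renderer wraps the processor
trainer = BaseTrainer(args=training_args, model=model, renderer=renderer, dataset=train_dataset)
trainer.fit()  # _create_dataloader and _init_model_and_optimizer are now internal hooks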

View File

@@ -87,7 +87,7 @@ class ModelEngine:
     def _init_model(self) -> HFModel:
         """Init model.

-        Let transformers handle the model init context.
+        Transformers can choose the proper model init context.
         https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
         """
         if self.args.model_class == ModelClass.LLM:
@@ -141,7 +141,7 @@ class ModelEngine:
         from ..plugins.model_plugins.kernels.interface import KernelPlugin

         model = KernelPlugin(self.args.kernel_config.name)(
-            model=model, include_kernels=self.args.kernel_config.get("include_kernels")
+            model, include_kernels=self.args.kernel_config.get("include_kernels")
         )

         return model

View File

@@ -24,12 +24,13 @@ Init Phase:
 """

 import importlib
 from pathlib import Path

-from ....utils.logging import get_logger
+from ....utils import logging
 from ....utils.plugin import BasePlugin
+from ....utils.types import HFModel
 from .registry import Registry


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def scan_all_kernels():
@@ -110,27 +111,30 @@ class KernelPlugin(BasePlugin):
 @KernelPlugin("auto").register()
-def apply_default_kernels(**kwargs):
+def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
     """Applies all default registered kernels to the model.

     Args:
-        **kwargs: Keyword arguments passed to the kernel application function.
-            Typically includes the model instance and the include_kernels configuration.
+        model (HFModel): The model instance to apply kernels to.
+        include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
+            If "auto" or True, applies all default kernels.
+            If None or False, no kernels are applied.
+            Defaults to None.

     Returns:
         HFModel: The model with applied kernels.
     """
-    if not kwargs.get("include_kernels"):  # None/False/empty string
-        return kwargs.get("model")
-    elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True:  # True/auto
+    if not include_kernels:
+        return model
+    elif include_kernels == "auto" or include_kernels is True:
         use_kernels = default_kernels.keys()
     else:
-        use_kernels = kwargs.get("include_kernels").split(",")  # "kernel_id1,kernel_id2,kernel_id3"
+        use_kernels = include_kernels.split(",")  # "kernel_id1,kernel_id2,kernel_id3"

     for kernel in use_kernels:
         if kernel not in default_kernels:
             raise ValueError(f"Kernel {kernel} not found")

-        apply_kernel(kernel, **kwargs)
+        apply_kernel(kernel, model=model)

-    return kwargs.get("model")
+    return model
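
The accepted include_kernels values, illustrated against the new signature (the kernel IDs below are hypothetical):

apply_default_kernels(model)                                    # None: model returned unchanged
apply_default_kernels(model, include_kernels="auto")            # apply all default kernels
apply_default_kernels(model, include_kernels=True)              # same as "auto"
apply_default_kernels(model, include_kernels="rmsnorm,swiglu")  # explicit comma-separated IDs
apply_default_kernels(model, include_kernels="missing")         # raises ValueError("Kernel missing not found")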

View File

@@ -20,8 +20,6 @@ Init Phase:
 """

-from typing import Optional
-
 from ....accelerator.helper import get_current_accelerator
 from .base import BaseKernel
@@ -73,14 +71,14 @@ class Registry:
         return kernel_cls

     @classmethod
-    def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
+    def get(cls, kernel_id: str) -> type[BaseKernel] | None:
         """Retrieves a registered kernel implementation by its ID.

         Args:
             kernel_id (str): The ID of the kernel to retrieve.

         Returns:
-            Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
+            type[BaseKernel] | None: The kernel class if found, else ``None``.
         """
         return cls._kernels.get(kernel_id)
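
On the caller side the PEP 604 annotation behaves exactly like Optional; a minimal sketch with a hypothetical kernel ID:

kernel_cls = Registry.get("rmsnorm")  # hypothetical kernel ID
if kernel_cls is None:
    raise ValueError("Kernel rmsnorm not found")  # callers must handle the None case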