[v1] use async streamer (#9741)

Yaowei Zheng
2026-01-09 16:07:40 +08:00
committed by hiyouga
parent 766d5ae6ad
commit 8abb8fb533
6 changed files with 47 additions and 57 deletions

View File

@@ -15,11 +15,11 @@
 import asyncio
 import os
 from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator
 from threading import Thread
 
 import torch
-from transformers import TextIteratorStreamer
+from transformers import AsyncTextIteratorStreamer
 
 from ..accelerator.interface import DistributedInterface
 from ..config import ModelArguments, SampleArguments, SampleBackend
@@ -88,39 +88,26 @@ class HuggingFaceEngine(BaseEngine):
         self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
 
     @torch.inference_mode()
-    def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
-        model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
-        streamer = TextIteratorStreamer(
-            tokenizer=get_tokenizer(self.renderer.processor),
-            skip_prompt=True,
-            skip_special_tokens=True,  # TODO: configurable
-        )
-        device = DistributedInterface().current_device
-        kwargs = {
-            "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
-            "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
-            "max_new_tokens": self.args.max_new_tokens,
-            "streamer": streamer,
-        }
-        thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
-        thread.start()
-
-        def stream():
-            try:
-                return streamer.__next__()
-            except StopIteration:
-                raise StopAsyncIteration()
-
-        return stream
-
     async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
         async with self.semaphore:
-            response = self.get_response(messages, tools)
-            while True:
-                try:
-                    yield await asyncio.to_thread(response)
-                except StopAsyncIteration:
-                    break
+            model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
+            streamer = AsyncTextIteratorStreamer(
+                tokenizer=get_tokenizer(self.renderer.processor),
+                skip_prompt=True,
+                skip_special_tokens=True,  # TODO: configurable
+            )
+            device = DistributedInterface().current_device
+            kwargs = {
+                "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
+                "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
+                "max_new_tokens": self.args.max_new_tokens,
+                "streamer": streamer,
+            }
+            thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
+            thread.start()
+
+            async for token in streamer:
+                yield token
 
     async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
         """Batch infer samples.

View File

@@ -28,8 +28,9 @@ Train Phase:
 """
 
 from ..config.training_args import TrainingArguments
-from ..utils.types import HFModel, Processor, TorchDataset
-from .trainer_utils.data_collator import DataCollator
+from ..utils.types import HFModel, TorchDataset
+from .utils.data_collator import DataCollator
+from .utils.rendering import Renderer
 
 
 class BaseTrainer:
@@ -37,21 +38,21 @@ class BaseTrainer:
         self,
         args: TrainingArguments,
         model: HFModel,
-        processor: Processor,
+        renderer: Renderer,
         dataset: TorchDataset,
     ) -> None:
         self.args = args
         self.model = model
-        self.processor = processor
+        self.renderer = renderer
         self.dataset = dataset
         self.data_collator = DataCollator()
         self.optimizer = None
         self.lr_scheduler = None
 
-    def init_model_and_optimizer(self) -> None:
+    def _create_dataloader(self) -> None:
         pass
 
-    def create_dataloader(self) -> None:
+    def _init_model_and_optimizer(self) -> None:
         pass
 
     def fit(self) -> None:
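
Note: the renamed, underscore-prefixed hooks make BaseTrainer read as a template method: subclasses fill in dataloader and optimizer construction while fit() drives them. A hypothetical subclass sketch (the SFTTrainer name and the args fields used below are assumptions, not from this repository):

# Sketch only: illustrates overriding the renamed private hooks.
class SFTTrainer(BaseTrainer):
    def _create_dataloader(self) -> None:
        from torch.utils.data import DataLoader

        # self.args.batch_size is an assumed field on TrainingArguments.
        self.dataloader = DataLoader(
            self.dataset, batch_size=self.args.batch_size, collate_fn=self.data_collator
        )

    def _init_model_and_optimizer(self) -> None:
        import torch

        # self.args.lr is likewise an assumed field.
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)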

View File

@@ -87,7 +87,7 @@ class ModelEngine:
     def _init_model(self) -> HFModel:
         """Init model.
 
-        Let transformers handle the model init context.
+        Transformers can choose the proper model init context.
         https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
         """
         if self.args.model_class == ModelClass.LLM:
@@ -141,7 +141,7 @@ class ModelEngine:
             from ..plugins.model_plugins.kernels.interface import KernelPlugin
 
             model = KernelPlugin(self.args.kernel_config.name)(
-                model=model, include_kernels=self.args.kernel_config.get("include_kernels")
+                model, include_kernels=self.args.kernel_config.get("include_kernels")
             )
 
         return model
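
Note: dropping the model= keyword also makes the call robust to how the plugin's __call__ names its first parameter: a keyword argument must match the parameter name exactly, while a positional argument works even if that parameter is renamed or positional-only. A hypothetical shape, not the real KernelPlugin API:

# Sketch only: a callable plugin with a positional-only model parameter.
class KernelPlugin:
    def __init__(self, name: str) -> None:
        self.name = name

    def __call__(self, model, /, include_kernels: str | None = None):
        # Would patch the named kernels into the model; a no-op in this sketch.
        return model


plugin = KernelPlugin("example")
plugin(object(), include_kernels="rmsnorm")  # fine
# plugin(model=object())  # TypeError: model is positional-only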