mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-24 06:39:08 +08:00
[v1] use async streamer (#9741)
This commit is contained in:
@@ -15,11 +15,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from collections.abc import AsyncGenerator
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
from transformers import TextIteratorStreamer
|
||||
from transformers import AsyncTextIteratorStreamer
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config import ModelArguments, SampleArguments, SampleBackend
|
||||
@@ -88,39 +88,26 @@ class HuggingFaceEngine(BaseEngine):
|
||||
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
|
||||
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer=get_tokenizer(self.renderer.processor),
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True, # TODO: configurable
|
||||
)
|
||||
device = DistributedInterface().current_device
|
||||
kwargs = {
|
||||
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
|
||||
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
|
||||
"max_new_tokens": self.args.max_new_tokens,
|
||||
"streamer": streamer,
|
||||
}
|
||||
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def stream():
|
||||
try:
|
||||
return streamer.__next__()
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
return stream
|
||||
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
async with self.semaphore:
|
||||
response = self.get_response(messages, tools)
|
||||
while True:
|
||||
try:
|
||||
yield await asyncio.to_thread(response)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
|
||||
streamer = AsyncTextIteratorStreamer(
|
||||
tokenizer=get_tokenizer(self.renderer.processor),
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True, # TODO: configurable
|
||||
)
|
||||
device = DistributedInterface().current_device
|
||||
kwargs = {
|
||||
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
|
||||
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
|
||||
"max_new_tokens": self.args.max_new_tokens,
|
||||
"streamer": streamer,
|
||||
}
|
||||
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
|
||||
async for token in streamer:
|
||||
yield token
|
||||
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
@@ -28,8 +28,9 @@ Train Phase:
|
||||
"""
|
||||
|
||||
from ..config.training_args import TrainingArguments
|
||||
from ..utils.types import HFModel, Processor, TorchDataset
|
||||
from .trainer_utils.data_collator import DataCollator
|
||||
from ..utils.types import HFModel, TorchDataset
|
||||
from .utils.data_collator import DataCollator
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
@@ -37,21 +38,21 @@ class BaseTrainer:
|
||||
self,
|
||||
args: TrainingArguments,
|
||||
model: HFModel,
|
||||
processor: Processor,
|
||||
renderer: Renderer,
|
||||
dataset: TorchDataset,
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.renderer = renderer
|
||||
self.dataset = dataset
|
||||
self.data_collator = DataCollator()
|
||||
self.optimizer = None
|
||||
self.lr_scheduler = None
|
||||
|
||||
def init_model_and_optimizer(self) -> None:
|
||||
def _create_dataloader(self) -> None:
|
||||
pass
|
||||
|
||||
def create_dataloader(self) -> None:
|
||||
def _init_model_and_optimizer(self) -> None:
|
||||
pass
|
||||
|
||||
def fit(self) -> None:
|
||||
|
||||
@@ -87,7 +87,7 @@ class ModelEngine:
|
||||
def _init_model(self) -> HFModel:
|
||||
"""Init model.
|
||||
|
||||
Let transformers handle the model init context.
|
||||
Transformers can choose the proper model init context.
|
||||
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
||||
"""
|
||||
if self.args.model_class == ModelClass.LLM:
|
||||
@@ -141,7 +141,7 @@ class ModelEngine:
|
||||
from ..plugins.model_plugins.kernels.interface import KernelPlugin
|
||||
|
||||
model = KernelPlugin(self.args.kernel_config.name)(
|
||||
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user