diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 8c2462238..88856492d 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -230,7 +230,7 @@ def load_model(
         )
         from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
 
-        model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
+        model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)
 
     trainable_params, all_param = count_parameters(model)
     if is_trainable:
diff --git a/src/llamafactory/v1/core/base_sampler.py b/src/llamafactory/v1/core/base_sampler.py
index aaf376a9f..efa90bec1 100644
--- a/src/llamafactory/v1/core/base_sampler.py
+++ b/src/llamafactory/v1/core/base_sampler.py
@@ -15,11 +15,11 @@
 import asyncio
 import os
 from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator, Generator
+from collections.abc import AsyncGenerator
 from threading import Thread
 
 import torch
-from transformers import TextIteratorStreamer
+from transformers import AsyncTextIteratorStreamer
 
 from ..accelerator.interface import DistributedInterface
 from ..config import ModelArguments, SampleArguments, SampleBackend
@@ -88,39 +88,26 @@ class HuggingFaceEngine(BaseEngine):
         self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
 
     @torch.inference_mode()
-    def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
-        model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
-        streamer = TextIteratorStreamer(
-            tokenizer=get_tokenizer(self.renderer.processor),
-            skip_prompt=True,
-            skip_special_tokens=True,  # TODO: configurable
-        )
-        device = DistributedInterface().current_device
-        kwargs = {
-            "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
-            "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
-            "max_new_tokens": self.args.max_new_tokens,
-            "streamer": streamer,
-        }
-        thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
-        thread.start()
-
-        def stream():
-            try:
-                return streamer.__next__()
-            except StopIteration:
-                raise StopAsyncIteration()
-
-        return stream
-
     async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
         async with self.semaphore:
-            response = self.get_response(messages, tools)
-            while True:
-                try:
-                    yield await asyncio.to_thread(response)
-                except StopAsyncIteration:
-                    break
+            model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
+            streamer = AsyncTextIteratorStreamer(
+                tokenizer=get_tokenizer(self.renderer.processor),
+                skip_prompt=True,
+                skip_special_tokens=True,  # TODO: configurable
+            )
+            device = DistributedInterface().current_device
+            kwargs = {
+                "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
+                "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
+                "max_new_tokens": self.args.max_new_tokens,
+                "streamer": streamer,
+            }
+            thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
+            thread.start()
+
+            async for token in streamer:
+                yield token
 
     async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
         """Batch infer samples.
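For illustration, a minimal sketch of driving the new AsyncTextIteratorStreamer-based generate() from an event loop, assuming an already-initialized HuggingFaceEngine bound to a placeholder name `engine` and messages in the list-of-Message format expected by render_messages:

import asyncio

async def collect_stream(engine, messages) -> str:
    # generate() is now an async generator: tokens produced by the background
    # model.generate() thread are awaited directly from AsyncTextIteratorStreamer
    # instead of being polled through asyncio.to_thread.
    pieces = []
    async for token in engine.generate(messages):
        pieces.append(token)
    return "".join(pieces)

# Usage (engine and messages are assumed to exist):
# text = asyncio.run(collect_stream(engine, messages))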
diff --git a/src/llamafactory/v1/core/base_trainer.py b/src/llamafactory/v1/core/base_trainer.py
index 22c106a8e..96b2f5e2e 100644
--- a/src/llamafactory/v1/core/base_trainer.py
+++ b/src/llamafactory/v1/core/base_trainer.py
@@ -28,8 +28,9 @@ Train Phase:
 """
 
 from ..config.training_args import TrainingArguments
-from ..utils.types import HFModel, Processor, TorchDataset
-from .trainer_utils.data_collator import DataCollator
+from ..utils.types import HFModel, TorchDataset
+from .utils.data_collator import DataCollator
+from .utils.rendering import Renderer
 
 
 class BaseTrainer:
@@ -37,21 +38,21 @@ def __init__(
         self,
         args: TrainingArguments,
         model: HFModel,
-        processor: Processor,
+        renderer: Renderer,
         dataset: TorchDataset,
     ) -> None:
         self.args = args
         self.model = model
-        self.processor = processor
+        self.renderer = renderer
         self.dataset = dataset
         self.data_collator = DataCollator()
         self.optimizer = None
         self.lr_scheduler = None
 
-    def init_model_and_optimizer(self) -> None:
+    def _create_dataloader(self) -> None:
         pass
 
-    def create_dataloader(self) -> None:
+    def _init_model_and_optimizer(self) -> None:
         pass
 
     def fit(self) -> None:
diff --git a/src/llamafactory/v1/core/model_engine.py b/src/llamafactory/v1/core/model_engine.py
index f69b4c203..849d37fa9 100644
--- a/src/llamafactory/v1/core/model_engine.py
+++ b/src/llamafactory/v1/core/model_engine.py
@@ -87,7 +87,7 @@ class ModelEngine:
     def _init_model(self) -> HFModel:
         """Init model.
 
-        Let transformers handle the model init context.
+        Transformers can choose the proper model init context.
         https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
         """
         if self.args.model_class == ModelClass.LLM:
@@ -141,7 +141,7 @@
             from ..plugins.model_plugins.kernels.interface import KernelPlugin
 
             model = KernelPlugin(self.args.kernel_config.name)(
-                model=model, include_kernels=self.args.kernel_config.get("include_kernels")
+                model, include_kernels=self.args.kernel_config.get("include_kernels")
             )
 
         return model
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
index 446cd57a6..7967a4328 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
@@ -24,12 +24,13 @@ Init Phase:
 import importlib
 from pathlib import Path
 
-from ....utils.logging import get_logger
+from ....utils import logging
 from ....utils.plugin import BasePlugin
+from ....utils.types import HFModel
 from .registry import Registry
 
 
-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 
 
 def scan_all_kernels():
@@ -110,27 +111,30 @@ class KernelPlugin(BasePlugin):
 
 
 @KernelPlugin("auto").register()
-def apply_default_kernels(**kwargs):
+def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
     """Applies all default registered kernels to the model.
 
     Args:
-        **kwargs: Keyword arguments passed to the kernel application function.
-            Typically includes the model instance and the include_kernels configuration.
+        model (HFModel): The model instance to apply kernels to.
+        include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
+            If "auto" or True, applies all default kernels.
+            If None or False, no kernels are applied.
+            Defaults to None.
 
     Returns:
         HFModel: The model with applied kernels.
""" - if not kwargs.get("include_kernels"): # None/False/empty string - return kwargs.get("model") - elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto + if not include_kernels: + return model + elif include_kernels == "auto" or include_kernels is True: use_kernels = default_kernels.keys() else: - use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3" + use_kernels = include_kernels.split(",") # "kernel_id1,kernel_id2,kernel_id3" for kernel in use_kernels: if kernel not in default_kernels: raise ValueError(f"Kernel {kernel} not found") - apply_kernel(kernel, **kwargs) + apply_kernel(kernel, model=model) - return kwargs.get("model") + return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index 16a5c8e0f..2621e4bad 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -20,8 +20,6 @@ Init Phase: """ -from typing import Optional - from ....accelerator.helper import get_current_accelerator from .base import BaseKernel @@ -73,14 +71,14 @@ class Registry: return kernel_cls @classmethod - def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]: + def get(cls, kernel_id: str) -> type[BaseKernel] | None: """Retrieves a registered kernel implementation by its ID. Args: kernel_id (str): The ID of the kernel to retrieve. Returns: - Optional[type[BaseKernel]]: The kernel class if found, else ``None``. + type[BaseKernel] | None: The kernel class if found, else ``None``. """ return cls._kernels.get(kernel_id)