3 Commits

Author SHA1 Message Date
Yaowei Zheng
df4c45c9ae [ci] fix workflow (#9738) 2026-01-09 14:48:16 +08:00
Yaowei Zheng
af3b6f5418 [model] clean obsolete models (#9736) 2026-01-09 14:08:18 +08:00
ZZHbible
5aacbe8434 [fix] fp8 (#9735)
Co-authored-by: jeremy.zhang <jeremy.zhang@temu.com>
2026-01-09 13:29:28 +08:00
6 changed files with 57 additions and 47 deletions

View File

@@ -230,7 +230,7 @@ def load_model(
)
from ..v1.plugins.model_plugins.kernels.interface import apply_default_kernels
model = apply_default_kernels(model, include_kernels=model_args.use_v1_kernels)
model = apply_default_kernels(model=model, include_kernels=model_args.use_v1_kernels)
trainable_params, all_param = count_parameters(model)
if is_trainable:

View File

@@ -15,11 +15,11 @@
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable, Generator
from threading import Thread
import torch
from transformers import AsyncTextIteratorStreamer
from transformers import TextIteratorStreamer
from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
@@ -88,26 +88,39 @@ class HuggingFaceEngine(BaseEngine):
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
def get_response(self, messages: list[Message], tools: str | None = None) -> Callable[[], str]:
    """Start generation in a background thread and return a chunk puller.

    The prompt is rendered with ``self.renderer``, moved to the current device,
    and handed to ``self.model.generate`` on a daemon thread together with a
    ``TextIteratorStreamer``. The method returns immediately after the thread
    has been started.

    Args:
        messages: Conversation passed to ``self.renderer.render_messages``.
        tools: Optional tool description forwarded to the renderer.

    Returns:
        A zero-argument callable. Each call returns the next item produced by
        the streamer. It is a plain function, not a generator object, so it
        must be called repeatedly rather than iterated.

    Raises:
        StopAsyncIteration: Raised by the returned callable (not by this
            method itself) once the streamer is exhausted.
    """
    model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
    streamer = TextIteratorStreamer(
        tokenizer=get_tokenizer(self.renderer.processor),
        skip_prompt=True,
        skip_special_tokens=True,  # TODO: configurable
    )
    device = DistributedInterface().current_device
    kwargs = {
        "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
        "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
        "max_new_tokens": self.args.max_new_tokens,
        "streamer": streamer,
    }
    # NOTE(review): torch.inference_mode() is documented as thread-local, so the
    # decorator presumably does not extend to generate() running on this worker
    # thread -- confirm that generate() disables gradients on its own.
    thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
    thread.start()

    def stream() -> str:
        try:
            return next(streamer)
        except StopIteration:
            # Translate the end-of-stream signal: the puller is meant to be run
            # through asyncio.to_thread, and StopIteration cannot be raised into
            # an asyncio Future. The original context carries no information.
            raise StopAsyncIteration() from None

    return stream
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = AsyncTextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
async for token in streamer:
yield token
response = self.get_response(messages, tools)
while True:
try:
yield await asyncio.to_thread(response)
except StopAsyncIteration:
break
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.

View File

@@ -28,9 +28,8 @@ Train Phase:
"""
from ..config.training_args import TrainingArguments
from ..utils.types import HFModel, TorchDataset
from .utils.data_collator import DataCollator
from .utils.rendering import Renderer
from ..utils.types import HFModel, Processor, TorchDataset
from .trainer_utils.data_collator import DataCollator
class BaseTrainer:
@@ -38,21 +37,21 @@ class BaseTrainer:
self,
args: TrainingArguments,
model: HFModel,
renderer: Renderer,
processor: Processor,
dataset: TorchDataset,
) -> None:
self.args = args
self.model = model
self.renderer = renderer
self.processor = processor
self.dataset = dataset
self.data_collator = DataCollator()
self.optimizer = None
self.lr_scheduler = None
def _create_dataloader(self) -> None:
def init_model_and_optimizer(self) -> None:
pass
def _init_model_and_optimizer(self) -> None:
def create_dataloader(self) -> None:
pass
def fit(self) -> None:

View File

@@ -87,7 +87,7 @@ class ModelEngine:
def _init_model(self) -> HFModel:
"""Init model.
Transformers can choose the proper model init context.
Let transformers handle the model init context.
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
"""
if self.args.model_class == ModelClass.LLM:
@@ -141,7 +141,7 @@ class ModelEngine:
from ..plugins.model_plugins.kernels.interface import KernelPlugin
model = KernelPlugin(self.args.kernel_config.name)(
model, include_kernels=self.args.kernel_config.get("include_kernels")
model=model, include_kernels=self.args.kernel_config.get("include_kernels")
)
return model

View File

@@ -24,13 +24,12 @@ Init Phase:
import importlib
from pathlib import Path
from ....utils import logging
from ....utils.logging import get_logger
from ....utils.plugin import BasePlugin
from ....utils.types import HFModel
from .registry import Registry
logger = logging.get_logger(__name__)
logger = get_logger(__name__)
def scan_all_kernels():
@@ -111,30 +110,27 @@ class KernelPlugin(BasePlugin):
@KernelPlugin("auto").register()
def apply_default_kernels(model: HFModel, include_kernels: str = None) -> HFModel:
def apply_default_kernels(**kwargs):
"""Applies all default registered kernels to the model.
Args:
model (HFModel): The model instance to apply kernels to.
include_kernels (str, optional): Comma-separated list of kernel IDs to apply.
If "auto" or True, applies all default kernels.
If None or False, no kernels are applied.
Defaults to None.
**kwargs: Keyword arguments passed to the kernel application function.
Typically includes the model instance and the include_kernels configuration.
Returns:
HFModel: The model with applied kernels.
"""
if not include_kernels:
return model
elif include_kernels == "auto" or include_kernels is True:
if not kwargs.get("include_kernels"): # None/False/empty string
return kwargs.get("model")
elif kwargs.get("include_kernels") == "auto" or kwargs.get("include_kernels") is True: # True/auto
use_kernels = default_kernels.keys()
else:
use_kernels = include_kernels.split(",") # "kernel_id1,kernel_id2,kernel_id3"
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, model=model)
apply_kernel(kernel, **kwargs)
return model
return kwargs.get("model")

View File

@@ -20,6 +20,8 @@ Init Phase:
"""
from typing import Optional
from ....accelerator.helper import get_current_accelerator
from .base import BaseKernel
@@ -71,14 +73,14 @@ class Registry:
return kernel_cls
@classmethod
def get(cls, kernel_id: str) -> type[BaseKernel] | None:
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
Returns:
type[BaseKernel] | None: The kernel class if found, else ``None``.
Optional[type[BaseKernel]]: The kernel class if found, else ``None``.
"""
return cls._kernels.get(kernel_id)