From ea0b4e2466e7e250dbf0008d53343d7e6e62b21c Mon Sep 17 00:00:00 2001 From: Yaowei Zheng Date: Tue, 6 Jan 2026 23:31:27 +0800 Subject: [PATCH] [v1] add cli sampler (#9721) --- .github/workflows/tests_cuda.yml | 4 +- pyproject.toml | 2 +- src/llamafactory/v1/accelerator/helper.py | 16 +- src/llamafactory/v1/accelerator/interface.py | 17 +- src/llamafactory/v1/config/arg_utils.py | 6 +- src/llamafactory/v1/config/model_args.py | 5 + src/llamafactory/v1/core/base_sampler.py | 134 ++++++++-- src/llamafactory/v1/core/data_engine.py | 22 +- .../core/{model_loader.py => model_engine.py} | 35 +-- .../core/{trainer_utils => utils}/__init__.py | 0 .../core/{trainer_utils => utils}/callback.py | 0 .../{trainer_utils => utils}/data_collator.py | 0 .../{trainer_utils => utils}/data_loader.py | 0 .../{trainer_utils => utils}/lr_scheduler.py | 0 src/llamafactory/v1/core/utils/rendering.py | 239 ++++++++++++++++++ src/llamafactory/v1/launcher.py | 5 + .../v1/plugins/data_plugins/converter.py | 37 ++- .../v1/plugins/data_plugins/loader.py | 72 +++--- .../v1/plugins/data_plugins/template.py | 133 ---------- .../plugins/model_plugins/initialization.py | 8 +- .../v1/plugins/model_plugins/kernels/base.py | 8 +- .../model_plugins/kernels/interface.py | 16 +- .../kernels/ops/mlp/npu_fused_moe.py | 25 +- .../kernels/ops/mlp/npu_swiglu.py | 10 +- .../kernels/ops/rms_norm/npu_rms_norm.py | 7 +- .../kernels/ops/rope/npu_rope.py | 11 +- .../plugins/model_plugins/kernels/registry.py | 13 +- .../v1/plugins/model_plugins/peft.py | 4 +- .../v1/plugins/model_plugins/rendering.py | 36 +++ src/llamafactory/v1/samplers/cli_sampler.py | 102 +++++++- src/llamafactory/v1/trainers/sft_trainer.py | 8 +- src/llamafactory/v1/utils/constants.py | 2 + src/llamafactory/v1/utils/helper.py | 29 +++ src/llamafactory/v1/utils/logging.py | 2 +- src/llamafactory/v1/utils/plugin.py | 45 +++- src/llamafactory/v1/utils/types.py | 36 ++- tests/conftest.py | 12 +- tests/version.txt | 2 +- tests_v1/core/test_data_loader.py | 173 ------------- tests_v1/core/test_model_loader.py | 21 +- tests_v1/core/utils/test_data_loader.py | 171 +++++++++++++ tests_v1/core/utils/test_rendering.py | 65 +++++ .../plugins/data_plugins/test_converter.py | 4 +- .../plugins/model_plugins/test_init_plugin.py | 18 +- tests_v1/sampler/test_cli_sampler.py | 41 +++ 45 files changed, 1091 insertions(+), 505 deletions(-) rename src/llamafactory/v1/core/{model_loader.py => model_engine.py} (85%) rename src/llamafactory/v1/core/{trainer_utils => utils}/__init__.py (100%) rename src/llamafactory/v1/core/{trainer_utils => utils}/callback.py (100%) rename src/llamafactory/v1/core/{trainer_utils => utils}/data_collator.py (100%) rename src/llamafactory/v1/core/{trainer_utils => utils}/data_loader.py (100%) rename src/llamafactory/v1/core/{trainer_utils => utils}/lr_scheduler.py (100%) create mode 100644 src/llamafactory/v1/core/utils/rendering.py delete mode 100644 src/llamafactory/v1/plugins/data_plugins/template.py create mode 100644 src/llamafactory/v1/plugins/model_plugins/rendering.py create mode 100644 src/llamafactory/v1/utils/helper.py delete mode 100644 tests_v1/core/test_data_loader.py create mode 100644 tests_v1/core/utils/test_data_loader.py create mode 100644 tests_v1/core/utils/test_rendering.py create mode 100644 tests_v1/sampler/test_cli_sampler.py diff --git a/.github/workflows/tests_cuda.yml b/.github/workflows/tests_cuda.yml index 1c51e1c7a..53548f509 100644 --- a/.github/workflows/tests_cuda.yml +++ b/.github/workflows/tests_cuda.yml @@ -55,12 +55,12 @@ jobs: uv pip install -e . uv pip install -r requirements/dev.txt - - name: Cache HuggingFace models + - name: Cache files id: hf-hub-cache uses: actions/cache@v4 with: path: ${{ runner.temp }}/huggingface - key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }} + key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }} - name: Check quality run: | diff --git a/pyproject.toml b/pyproject.toml index 357df9314..0faa5d9e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,7 @@ dependencies = [ # api "uvicorn", "fastapi", - "sse-starlette" + "sse-starlette", ] [project.scripts] diff --git a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py index 04ded900c..c9cee520c 100644 --- a/src/llamafactory/v1/accelerator/helper.py +++ b/src/llamafactory/v1/accelerator/helper.py @@ -119,9 +119,19 @@ def synchronize() -> None: @requires_accelerator -def set_device() -> None: - """Set current accelerator.""" - torch.accelerator.set_device_index(get_local_rank()) +def set_device_index() -> None: + """Set current accelerator index to local rank.""" + if get_current_accelerator().type != DeviceType.CPU: + torch.accelerator.set_device_index(get_local_rank()) + + +@requires_accelerator +def get_current_device() -> torch.device: + """Get current accelerator device.""" + if get_current_accelerator().type == DeviceType.CPU: + return torch.device(DeviceType.CPU.value) + else: + return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index()) def is_torch_cuda_available(): diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index c7cdf4f5f..2837cbc74 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -123,12 +123,13 @@ class DistributedInterface: if self._initialized: return + helper.set_device_index() self._is_distributed = helper.is_distributed() self._rank = helper.get_rank() self._world_size = helper.get_world_size() self._local_rank = helper.get_local_rank() self._local_world_size = helper.get_local_world_size() - self.current_accelerator = helper.get_current_accelerator() + self.current_device = helper.get_current_device() self.device_count = helper.get_device_count() if config is None: @@ -144,15 +145,14 @@ class DistributedInterface: timeout = config.get("timeout", 18000) if self._is_distributed: - helper.set_device() init_process_group(timeout=timedelta(seconds=timeout)) self.model_device_mesh = init_device_mesh( - device_type=self.current_accelerator.type, + device_type=self.current_device.type, mesh_shape=self.strategy.model_mesh_shape, mesh_dim_names=self.strategy.model_mesh_dim_names, ) self.data_device_mesh = init_device_mesh( - device_type=self.current_accelerator.type, + device_type=self.current_device.type, mesh_shape=self.strategy.data_mesh_shape, mesh_dim_names=self.strategy.data_mesh_dim_names, ) @@ -161,12 +161,12 @@ class DistributedInterface: self.data_device_mesh = None self._initialized = True - logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.") + logger.info_rank0(f"DistributedInterface initialized: {self}.") def __str__(self) -> str: return ( f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, " - f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, " + f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, " f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}" ) @@ -251,4 +251,7 @@ class DistributedInterface: if __name__ == "__main__": - print(DistributedInterface(DistributedStrategy())) + """ + python -m llamafactory.v1.accelerator.interface + """ + print(DistributedInterface()) diff --git a/src/llamafactory/v1/config/arg_utils.py b/src/llamafactory/v1/config/arg_utils.py index 5335cdbb5..5673bd687 100644 --- a/src/llamafactory/v1/config/arg_utils.py +++ b/src/llamafactory/v1/config/arg_utils.py @@ -17,7 +17,7 @@ import json -from enum import Enum, unique +from enum import StrEnum, unique class PluginConfig(dict): @@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None @unique -class ModelClass(str, Enum): +class ModelClass(StrEnum): """Auto class for model config.""" LLM = "llm" @@ -45,7 +45,7 @@ class ModelClass(str, Enum): @unique -class SampleBackend(str, Enum): +class SampleBackend(StrEnum): HF = "hf" VLLM = "vllm" diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py index b79ee86de..d017cd149 100644 --- a/src/llamafactory/v1/config/model_args.py +++ b/src/llamafactory/v1/config/model_args.py @@ -21,8 +21,13 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config @dataclass class ModelArguments: model: str = field( + default="Qwen/Qwen3-4B-Instruct-2507", metadata={"help": "Path to the model or model identifier from Hugging Face."}, ) + template: str = field( + default="chatml", + metadata={"help": "Template for the model."}, + ) trust_remote_code: bool = field( default=False, metadata={"help": "Trust remote code from Hugging Face."}, diff --git a/src/llamafactory/v1/core/base_sampler.py b/src/llamafactory/v1/core/base_sampler.py index fbd36ab69..aaf376a9f 100644 --- a/src/llamafactory/v1/core/base_sampler.py +++ b/src/llamafactory/v1/core/base_sampler.py @@ -12,10 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import os from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator, Generator +from threading import Thread +import torch +from transformers import TextIteratorStreamer + +from ..accelerator.interface import DistributedInterface from ..config import ModelArguments, SampleArguments, SampleBackend -from ..utils.types import HFModel, Processor, TorchDataset +from ..utils.helper import get_tokenizer +from ..utils.types import HFModel, Message, Sample, TorchDataset +from .utils.rendering import Renderer class BaseEngine(ABC): @@ -24,8 +34,8 @@ class BaseEngine(ABC): self, args: SampleArguments, model_args: ModelArguments, - model: HFModel = None, - processor: Processor = None, + model: HFModel, + renderer: Renderer, ) -> None: """Initialize the engine. @@ -33,17 +43,34 @@ class BaseEngine(ABC): args: Sample arguments. model_args: Model arguments. model: Model. - processor: Processor. + renderer: Renderer. """ ... @abstractmethod - async def generate(self, messages): - pass + async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]: + """Generate tokens asynchronously. + + Args: + messages: List of messages. + tools: Tools string. + + Yields: + Generated tokens. + """ + ... @abstractmethod - async def batch_infer(self, data: TorchDataset) -> None: - pass + async def batch_infer(self, dataset: TorchDataset) -> list[Sample]: + """Batch infer samples. + + Args: + dataset: Torch dataset. + + Returns: + List of samples. + """ + ... class HuggingFaceEngine(BaseEngine): @@ -52,26 +79,103 @@ class HuggingFaceEngine(BaseEngine): args: SampleArguments, model_args: ModelArguments, model: HFModel, - processor: Processor, + renderer: Renderer, ) -> None: self.args = args + self.model_args = model_args + self.model = model + self.renderer = renderer + self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) + + @torch.inference_mode() + def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]: + model_inputs = self.renderer.render_messages(messages, tools, is_generate=True) + streamer = TextIteratorStreamer( + tokenizer=get_tokenizer(self.renderer.processor), + skip_prompt=True, + skip_special_tokens=True, # TODO: configurable + ) + device = DistributedInterface().current_device + kwargs = { + "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device), + "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device), + "max_new_tokens": self.args.max_new_tokens, + "streamer": streamer, + } + thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True) + thread.start() + + def stream(): + try: + return streamer.__next__() + except StopIteration: + raise StopAsyncIteration() + + return stream + + async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]: + async with self.semaphore: + response = self.get_response(messages, tools) + while True: + try: + yield await asyncio.to_thread(response) + except StopAsyncIteration: + break + + async def batch_infer(self, dataset: TorchDataset) -> list[Sample]: + """Batch infer samples. + + Args: + dataset: Torch dataset. + + Returns: + List of samples. + """ + raise NotImplementedError("Batch infer is not implemented.") class BaseSampler: + """Base sampler. + + Args: + args: Sample arguments. + model_args: Model arguments. + model: Model. + renderer: Renderer. + """ + def __init__( self, args: SampleArguments, model_args: ModelArguments, model: HFModel, - processor: Processor, + renderer: Renderer, ) -> None: if args.sample_backend == SampleBackend.HF: - self.engine = HuggingFaceEngine(args, model_args, model, processor) + self.engine = HuggingFaceEngine(args, model_args, model, renderer) else: raise ValueError(f"Unknown sample backend: {args.sample_backend}") - async def generate(self, messages): - return await self.engine.generate(messages) + async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]: + """Generate tokens asynchronously. - async def batch_infer(self, data: TorchDataset) -> None: - return await self.engine.batch_infer(data) + Args: + messages: List of messages. + tools: Tools string. + + Yields: + Generated tokens. + """ + async for token in self.engine.generate(messages, tools): + yield token + + async def batch_infer(self, dataset: TorchDataset) -> list[Sample]: + """Batch infer samples. + + Args: + dataset: Torch dataset. + + Returns: + List of samples. + """ + return await self.engine.batch_infer(dataset) diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index f0ebb00af..18f14f1b6 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -14,15 +14,23 @@ """The definition of data engine. -Init Data engine: +How to use: +data_engine = DataEngine(data_args) +data_engine[i]: Get the sample via index. + +Init workflow: 1. Parse dataset info from arguments. 2. Load datasets according to dataset info. 3. Build data index (and reweight samples if necessary). -Get Data Sample: +Get data sample: 1. Get sample from data index. 2. Convert sample to standard format. 3. Return sample. + +Note: +1. The data engine is equivalent to the torch dataset. +2. The data engine is agnostic to the model used. """ import os @@ -98,10 +106,10 @@ class DataEngine(Dataset): size = self.dataset_infos[dataset_name].get("size") weight = self.dataset_infos[dataset_name].get("weight") - if size or weight: # data index plugin - from ..plugins.data_plugins.loader import DataIndexPlugin + if size or weight: + from ..plugins.data_plugins.loader import adjust_data_index - data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight) + data_index = adjust_data_index(data_index, size, weight) self.data_index.extend(data_index) @@ -150,9 +158,9 @@ class DataEngine(Dataset): dataset_name, sample_index = self.data_index[index] return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) else: # data selector plugin - from ..plugins.data_plugins.loader import DataSelectorPlugin + from ..plugins.data_plugins.loader import select_data_sample - selected_index = DataSelectorPlugin().select(self.data_index, index) + selected_index = select_data_sample(self.data_index, index) if isinstance(selected_index, list): return [ self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name) diff --git a/src/llamafactory/v1/core/model_loader.py b/src/llamafactory/v1/core/model_engine.py similarity index 85% rename from src/llamafactory/v1/core/model_loader.py rename to src/llamafactory/v1/core/model_engine.py index 069292274..f69b4c203 100644 --- a/src/llamafactory/v1/core/model_loader.py +++ b/src/llamafactory/v1/core/model_engine.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The definition of model loader. +"""The definition of model engine. How to use: -model_loader = ModelLoader(model_args, is_trainable=True) -model_loader.processor: Get the tokenizer or multi-modal processor. -model_loader.model_config: Get the model configuration. -model_loader.model: Get the HF model. +model_engine = ModelEngine(model_args, is_train=True) +model_engine.processor: Get the tokenizer or multi-modal processor. +model_engine.renderer: Get the renderer. +model_engine.model_config: Get the model configuration. +model_engine.model: Get the HF model. -Init Workflow: +Init workflow: 1. Init processor. +2. Init render. 2. Init model config. 3. Init model. 4. Init adapter. @@ -36,17 +38,18 @@ from ..accelerator.interface import DistributedInterface from ..config.model_args import ModelArguments, ModelClass from ..utils import logging from ..utils.types import HFConfig, HFModel, Processor +from .utils.rendering import Renderer logger = logging.get_logger(__name__) -class ModelLoader: - """Model loader. +class ModelEngine: + """Model engine. Args: model_args: Model arguments. - is_trainable: Whether to train the model. + is_train: Whether to train the model. """ def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None: @@ -56,6 +59,8 @@ class ModelLoader: """Whether to train the model.""" self.processor = self._init_processor() """Tokenizer or multi-modal processor.""" + self.renderer = Renderer(self.args.template, self.processor) + """Renderer.""" self.model_config = self._init_model_config() """Model configuration.""" self.model = self._init_model() @@ -107,7 +112,7 @@ class ModelLoader: init_device = InitPlugin(self.args.init_config.name)() else: - init_device = DistributedInterface().current_accelerator + init_device = DistributedInterface().current_device if init_device.type == DeviceType.META: with init_empty_weights(): @@ -144,12 +149,12 @@ class ModelLoader: if __name__ == "__main__": """ - python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5 + python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5 """ from ..config.arg_parser import get_args _, model_args, *_ = get_args() - model_loader = ModelLoader(model_args=model_args) - print(model_loader.processor) - print(model_loader.model_config) - print(model_loader.model) + model_engine = ModelEngine(model_args=model_args) + print(model_engine.processor) + print(model_engine.model_config) + print(model_engine.model) diff --git a/src/llamafactory/v1/core/trainer_utils/__init__.py b/src/llamafactory/v1/core/utils/__init__.py similarity index 100% rename from src/llamafactory/v1/core/trainer_utils/__init__.py rename to src/llamafactory/v1/core/utils/__init__.py diff --git a/src/llamafactory/v1/core/trainer_utils/callback.py b/src/llamafactory/v1/core/utils/callback.py similarity index 100% rename from src/llamafactory/v1/core/trainer_utils/callback.py rename to src/llamafactory/v1/core/utils/callback.py diff --git a/src/llamafactory/v1/core/trainer_utils/data_collator.py b/src/llamafactory/v1/core/utils/data_collator.py similarity index 100% rename from src/llamafactory/v1/core/trainer_utils/data_collator.py rename to src/llamafactory/v1/core/utils/data_collator.py diff --git a/src/llamafactory/v1/core/trainer_utils/data_loader.py b/src/llamafactory/v1/core/utils/data_loader.py similarity index 100% rename from src/llamafactory/v1/core/trainer_utils/data_loader.py rename to src/llamafactory/v1/core/utils/data_loader.py diff --git a/src/llamafactory/v1/core/trainer_utils/lr_scheduler.py b/src/llamafactory/v1/core/utils/lr_scheduler.py similarity index 100% rename from src/llamafactory/v1/core/trainer_utils/lr_scheduler.py rename to src/llamafactory/v1/core/utils/lr_scheduler.py diff --git a/src/llamafactory/v1/core/utils/rendering.py b/src/llamafactory/v1/core/utils/rendering.py new file mode 100644 index 000000000..20460c2d5 --- /dev/null +++ b/src/llamafactory/v1/core/utils/rendering.py @@ -0,0 +1,239 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re + +from ...utils.constants import IGNORE_INDEX +from ...utils.helper import get_tokenizer +from ...utils.types import Message, ModelInput, Processor + + +def _update_model_input( + processor: Processor, + input_ids: list[int], + labels: list[int], + loss_weights: list[int], + temp_str: str, + temp_weight: float, +) -> str: + """Update model input with temporary string.""" + if not temp_str: + return "" + + tokenizer = get_tokenizer(processor) + temp_ids = tokenizer.encode(temp_str, add_special_tokens=False) + input_ids.extend(temp_ids) + loss_weights.extend([temp_weight] * len(temp_ids)) + if temp_weight > 1e-6: + labels.extend(temp_ids) + else: + labels.extend([IGNORE_INDEX] * len(temp_ids)) + + return "" + + +def render_chatml_messages( + processor: Processor, + messages: list[Message], + tools: str | None = None, + is_generate: bool = False, +) -> ModelInput: + """Apply chatml template to messages and convert them to model input. + + See https://huggingface.co/spaces/huggingfacejs/chat-template-playground + """ + input_ids, labels, loss_weights = [], [], [] + temp_str, temp_weight = "", 0.0 + if tools: + temp_str += "<|im_start|>system\n" + if messages[0]["role"] == "system": + for content in messages[0]["content"]: + if content["type"] == "text": + temp_str += content["value"] + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + temp_str += "\n\n" + temp_weight = messages[0].get("loss_weight", 0.0) + + temp_str += ( + "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n" + "You are provided with function signatures within XML tags:\n" + ) + try: + tools = json.loads(tools) + except json.JSONDecodeError: + raise ValueError(f"Invalid tools format: {str(tools)}.") + + if not isinstance(tools, list): + tools = [tools] + + for tool in tools: + temp_str += "\n" + json.dumps(tool, ensure_ascii=False) + + temp_str += ( + "\n\n\nFor each function call, return a json object with function name " + 'and arguments within XML tags:\n\n{"name": ' + ', "arguments": }\n<|im_end|>\n' + ) + elif messages[0]["role"] == "system": + temp_str += "<|im_start|>system\n" + for content in messages[0]["content"]: + if content["type"] == "text": + temp_str += content["value"] + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + temp_str += "<|im_end|>\n" + temp_weight = messages[0].get("loss_weight", 0.0) + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + for turn_idx, message in enumerate(messages): + if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0): + temp_str += "<|im_start|>" + message["role"] + "\n" + for content in message["content"]: + if content["type"] == "text": + temp_str += content["value"] + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + temp_str += "<|im_end|>\n" + temp_weight = message.get("loss_weight", 0.0) + elif message["role"] == "assistant": + temp_str += "<|im_start|>" + message["role"] + "\n" + for val_idx, content in enumerate(message["content"]): + if content["type"] == "text": + temp_str += content["value"] + elif content["type"] == "reasoning": + temp_str += "\n" + content["value"] + "\n\n\n" # avoid using special tokens + elif content["type"] == "tool_call": + if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]: + temp_str += "\n" + + try: + tool_call = json.loads(content["value"]) + except json.JSONDecodeError: + raise ValueError(f"Invalid tool call format: {content['value']}.") + temp_str += ( + '\n{"name": "' + + tool_call["name"] + + '", "arguments": ' + + json.dumps(tool_call["arguments"], ensure_ascii=False) + + "}\n" + ) + + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + temp_str += "<|im_end|>\n" + temp_weight = message.get("loss_weight", 1.0) + elif message["role"] == "tool": + if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool": + temp_str += "<|im_start|>user" + + temp_str += "\n\n" + for content in message["content"]: + if content["type"] == "text": + temp_str += content["value"] + else: + raise ValueError(f"Unsupported content type: {content['type']}") + + temp_str += "\n" + if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool": + temp_str += "<|im_end|>\n" + + temp_weight = message.get("loss_weight", 0.0) + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + if is_generate: + temp_str += "<|im_start|>assistant\n" + temp_weight = 0.0 + + temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) + + attention_mask = [1] * len(input_ids) + return ModelInput( + input_ids=input_ids, + attention_mask=attention_mask, + labels=labels, + loss_weights=loss_weights, + ) + + +def parse_chatml_message(generated_text: str) -> Message: + """Parse a message in ChatML format. Supports interleaved reasoning and tool calls. + + Args: + generated_text (str): The generated text in ChatML format. + + Returns: + Message: The parsed message. + """ + pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*\s*", re.DOTALL) + content = [] + last_end = 0 + for match in pattern.finditer(generated_text): + start, end = match.span() + if start > last_end: + text = generated_text[last_end:start].strip() + if text: + content.append({"type": "text", "value": text}) + + tag_type = match.group(1) + tag_value = match.group(2).strip() + if tag_type == "thinking": + content.append({"type": "reasoning", "value": tag_value.strip()}) + elif tag_type == "tool_call": + try: + json.loads(tag_value.strip()) + except json.JSONDecodeError: + raise ValueError(f"Invalid tool call format: {tag_value.strip()}.") + + content.append({"type": "tool_call", "value": tag_value.strip()}) + + last_end = end + + if last_end < len(generated_text): + text = generated_text[last_end:].strip() + if text: + content.append({"type": "text", "value": text}) + + return Message(role="assistant", content=content) + + +class Renderer: + def __init__(self, template: str, processor: Processor): + self.template = template + self.processor = processor + + def render_messages( + self, messages: list[Message], tools: str | None = None, is_generate: bool = False + ) -> ModelInput: + if self.template == "chatml": + return render_chatml_messages(self.processor, messages, tools, is_generate) + else: + from ...plugins.model_plugins.rendering import RenderingPlugin + + return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate) + + def parse_message(self, generated_text: str) -> Message: + if self.template == "chatml": + return parse_chatml_message(generated_text) + else: + from ...plugins.model_plugins.rendering import RenderingPlugin + + return RenderingPlugin(self.template).parse_message(generated_text) diff --git a/src/llamafactory/v1/launcher.py b/src/llamafactory/v1/launcher.py index 3d18286bc..9f04ab398 100644 --- a/src/llamafactory/v1/launcher.py +++ b/src/llamafactory/v1/launcher.py @@ -49,6 +49,11 @@ def launch(): run_sft() + elif command == "chat": + from .samplers.cli_sampler import run_chat + + run_chat() + elif command == "env": print_env() diff --git a/src/llamafactory/v1/plugins/data_plugins/converter.py b/src/llamafactory/v1/plugins/data_plugins/converter.py index 07aae2dfe..ca7b4fa98 100644 --- a/src/llamafactory/v1/plugins/data_plugins/converter.py +++ b/src/llamafactory/v1/plugins/data_plugins/converter.py @@ -13,11 +13,12 @@ # limitations under the License. +import json from typing import Any, Literal, NotRequired, TypedDict from ...utils import logging from ...utils.plugin import BasePlugin -from ...utils.types import DPOSample, Sample, SFTSample +from ...utils.types import DPOSample, Sample, SFTSample, ToolCall logger = logging.get_logger(__name__) @@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin): return super().__call__(raw_sample) -@DataConverterPlugin("alpaca").register +@DataConverterPlugin("alpaca").register() def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: """Convert Alpaca sample to SFT sample. @@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample: return {"messages": messages} -@DataConverterPlugin("sharegpt").register +@DataConverterPlugin("sharegpt").register() def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: """Convert ShareGPT sample to SFT sample. @@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: "function_call": "assistant", } messages = [] - tools = raw_sample.get("tools", "") + tools = raw_sample.get("tools") + if tools: + try: + tools: list[dict[str, Any]] = json.loads(tools) + except json.JSONDecodeError: + logger.warning_rank0(f"Invalid tools format: {str(tools)}") + tools = [] for message in raw_sample.get("conversations", []): tag = message["from"] if tag not in tag_mapping: logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}") elif tag == "function_call": + try: + tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"]) + except json.JSONDecodeError: + logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}") + continue + + if not isinstance(tool_calls, list): + tool_calls = [tool_calls] + messages.append( { "role": "assistant", - "content": [{"type": "tool_calls", "value": message["value"]}], + "content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls], "loss_weight": 1.0, } ) @@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample: ) if tools: - if messages and messages[0]["role"] == "system": - messages[0]["content"].append({"type": "tools", "value": tools}) - else: - messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0}) - - return {"messages": messages} + return {"messages": messages, "extra_info": json.dumps({"tools": tools})} + else: + return {"messages": messages} -@DataConverterPlugin("pair").register +@DataConverterPlugin("pair").register() def pair_converter(raw_sample: PairSample) -> DPOSample: """Convert Pair sample to DPO sample. diff --git a/src/llamafactory/v1/plugins/data_plugins/loader.py b/src/llamafactory/v1/plugins/data_plugins/loader.py index 9200ec34a..eb24ca614 100644 --- a/src/llamafactory/v1/plugins/data_plugins/loader.py +++ b/src/llamafactory/v1/plugins/data_plugins/loader.py @@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", " raise ValueError(f"Unknown dataset filetype: {filetype}.") -@DataLoaderPlugin("local").register +@DataLoaderPlugin("local").register() def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset: if os.path.isdir(filepath): filetype = _get_builder_name(os.listdir(filepath)[0]) @@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset return dataset -class DataIndexPlugin(BasePlugin): - """Plugin for adjusting dataset index.""" +def adjust_data_index( + data_index: list[tuple[str, int]], size: int | None, weight: float | None +) -> list[tuple[str, int]]: + """Adjust dataset index by size and weight. - def adjust_data_index( - self, data_index: list[tuple[str, int]], size: int | None, weight: float | None - ) -> list[tuple[str, int]]: - """Adjust dataset index by size and weight. + Args: + data_index (list[tuple[str, int]]): List of (dataset_name, sample_index). + size (Optional[int]): Desired dataset size. + weight (Optional[float]): Desired dataset weight. - Args: - data_index (list[tuple[str, int]]): List of (dataset_name, sample_index). - size (Optional[int]): Desired dataset size. - weight (Optional[float]): Desired dataset weight. + Returns: + list[tuple[str, int]]: Adjusted dataset index. + """ + if size is not None: + data_index = random.choices(data_index, k=size) - Returns: - list[tuple[str, int]]: Adjusted dataset index. - """ - if size is not None: - data_index = random.choices(data_index, k=size) + if weight is not None: + data_index = random.choices(data_index, k=int(len(data_index) * weight)) - if weight is not None: - data_index = random.choices(data_index, k=int(len(data_index) * weight)) - - return data_index + return data_index -class DataSelectorPlugin(BasePlugin): - """Plugin for selecting dataset samples.""" +def select_data_sample( + data_index: list[tuple[str, int]], index: slice | list[int] | Any +) -> tuple[str, int] | list[tuple[str, int]]: + """Select dataset samples. - def select( - self, data_index: list[tuple[str, int]], index: slice | list[int] | Any - ) -> tuple[str, int] | list[tuple[str, int]]: - """Select dataset samples. + Args: + data_index (list[tuple[str, int]]): List of (dataset_name, sample_index). + index (Union[slice, list[int], Any]): Index of dataset samples. - Args: - data_index (list[tuple[str, int]]): List of (dataset_name, sample_index). - index (Union[slice, list[int], Any]): Index of dataset samples. - - Returns: - Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples. - """ - if isinstance(index, slice): - return [data_index[i] for i in range(*index.indices(len(data_index)))] - elif isinstance(index, list): - return [data_index[i] for i in index] - else: - raise ValueError(f"Invalid index type {type(index)}.") + Returns: + Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples. + """ + if isinstance(index, slice): + return [data_index[i] for i in range(*index.indices(len(data_index)))] + elif isinstance(index, list): + return [data_index[i] for i in index] + else: + raise ValueError(f"Invalid index type {type(index)}.") diff --git a/src/llamafactory/v1/plugins/data_plugins/template.py b/src/llamafactory/v1/plugins/data_plugins/template.py deleted file mode 100644 index 96159142e..000000000 --- a/src/llamafactory/v1/plugins/data_plugins/template.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from dataclasses import dataclass - - -@dataclass -class Template: - user_template: str - assistant_template: str - system_template: str - - def render_message(self, message: dict[str, str]) -> str: - return self.user_template.format(**message) - - -@dataclass -class QwenTemplate: - message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool - thinking_template: str = "\n{content}\n\n\n" - - def _extract_content(self, content_data: str | list[dict[str, str]]) -> str: - if isinstance(content_data, str): - return content_data.strip() - - if isinstance(content_data, list): - parts = [] - for item in content_data: - if item.get("type") == "text": - parts.append(item.get("value", "")) - elif item.get("type") == "image_url": - pass - return "\n".join(parts).strip() - - return "" - - def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str: - role = message["role"] - content = self._extract_content(message.get("content", "")) - - if role == "assistant": - reasoning_content = message.get("reasoning_content", "") - if reasoning_content: - reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip()) - return self.message_template.format(role="assistant", content=reasoning_content + content) - else: - return self.message_template.format(role=role, content=content) - - def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any: - """Encode one message.""" - input_ids, attention_mask, labels = [], [], [] - for message in messages: - content_str = self.render_message(message) - content_ids = tokenizer.encode(content_str, add_special_tokens=False) - input_ids += content_ids - attention_mask += [1] * len(content_ids) - - if hasattr(message, "loss_weight"): - loss_weight = message["loss_weight"] - else: - loss_weight = 1 if message["role"] == "assistant" else 0 - if loss_weight == 1: - labels += content_ids - else: - labels += [-100] * len(content_ids) - model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} - model_inputs.update({"position_ids": list(range(len(input_ids)))}) - model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()} - return model_inputs - - -if __name__ == "__main__": - - def to_qwen3_messages(template: QwenTemplate, messages: list[dict]): - out = [] - for m in messages: - role = m["role"] - content = template._extract_content(m.get("content", "")) - if role == "assistant": - reasoning = (m.get("reasoning_content") or "").strip() - if reasoning: - content = template.thinking_template.format(content=reasoning) + content - out.append({"role": role, "content": content}) - return out - - from transformers import AutoTokenizer - - tok = AutoTokenizer.from_pretrained( - "Qwen/Qwen3-30B-A3B-Thinking-2507", - trust_remote_code=True, - ) - - test_messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - { - "role": "user", - "content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}], - }, - { - "role": "assistant", - "reasoning_content": "这是一个简单的数学问题。1加1的结果是2。", - "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}], - }, - ] - - template = QwenTemplate() - rendered_custom = "".join([template.render_message(m) for m in test_messages]) - - qwen3_messages = to_qwen3_messages(template, test_messages) - rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False) - - print("==== custom ====") - print(rendered_custom) - print("==== hf ====") - print(rendered_hf) - - assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch" - - ids_custom = tok.encode(rendered_custom, add_special_tokens=False) - ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False) - assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}" diff --git a/src/llamafactory/v1/plugins/model_plugins/initialization.py b/src/llamafactory/v1/plugins/model_plugins/initialization.py index 5e6c8bb99..efb8f22f7 100644 --- a/src/llamafactory/v1/plugins/model_plugins/initialization.py +++ b/src/llamafactory/v1/plugins/model_plugins/initialization.py @@ -25,12 +25,12 @@ class InitPlugin(BasePlugin): return super().__call__() -@InitPlugin("init_on_meta").register +@InitPlugin("init_on_meta").register() def init_on_meta() -> torch.device: return torch.device(DeviceType.META.value) -@InitPlugin("init_on_rank0").register +@InitPlugin("init_on_rank0").register() def init_on_rank0() -> torch.device: if DistributedInterface().get_rank() == 0: return torch.device(DeviceType.CPU.value) @@ -38,6 +38,6 @@ def init_on_rank0() -> torch.device: return torch.device(DeviceType.META.value) -@InitPlugin("init_on_default").register +@InitPlugin("init_on_default").register() def init_on_default() -> torch.device: - return DistributedInterface().current_accelerator + return DistributedInterface().current_device diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/base.py b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py index d5cd83be6..265986ccc 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/base.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py @@ -38,17 +38,17 @@ class BaseKernel(ABC): @classmethod def get_kernel_id(cls) -> str: - r"""Returns the unique identifier for the kernel.""" + """Returns the unique identifier for the kernel.""" return cls._kernel_id @classmethod def get_device(cls) -> str: - r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu").""" + """Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu").""" return cls._device @classmethod def check_deps(cls) -> bool: - r"""Checks if the required dependencies for the kernel are available. + """Checks if the required dependencies for the kernel are available. Returns: bool: ``True`` if dependencies are met, ``False`` otherwise. @@ -65,7 +65,7 @@ class BaseKernel(ABC): @classmethod @abstractmethod def apply(cls, **kwargs) -> HFModel: - r"""Applies the kernel optimization to the model. + """Applies the kernel optimization to the model. Args: **kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration. diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py index 19a2def19..446cd57a6 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py @@ -33,7 +33,7 @@ logger = get_logger(__name__) def scan_all_kernels(): - r"""Scan all kernels in the ``ops`` directory. + """Scan all kernels in the ``ops`` directory. Scans the ``ops`` directory for all ``.py`` files and attempts to import them. Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels. @@ -77,7 +77,7 @@ default_kernels = scan_all_kernels() def get_default_kernels(): - r"""Get a list of default registered kernel IDs. + """Get a list of default registered kernel IDs. Returns: list[str]: List of kernel IDs. @@ -86,7 +86,7 @@ def get_default_kernels(): def apply_kernel(kernel_id: str, **kwargs): - r"""Applies a specific kernel to the model. + """Applies a specific kernel to the model. Args: kernel_id (str): The ID of the kernel to apply. @@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs): kernel = default_kernels.get(kernel_id) if kernel is None: raise ValueError(f"Kernel {kernel_id} not found") + kernel.apply(**kwargs) class KernelPlugin(BasePlugin): - r"""Plugin for managing kernel optimizations.""" + """Plugin for managing kernel optimizations.""" pass -@KernelPlugin("auto").register +@KernelPlugin("auto").register() def apply_default_kernels(**kwargs): - r"""Applies all default registered kernels to the model. + """Applies all default registered kernels to the model. Args: **kwargs: Keyword arguments passed to the kernel application function. @@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs): use_kernels = default_kernels.keys() else: use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3" + for kernel in use_kernels: if kernel not in default_kernels: raise ValueError(f"Kernel {kernel} not found") + apply_kernel(kernel, **kwargs) + return kwargs.get("model") diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py index 0d84dbec8..62854777d 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py @@ -40,11 +40,11 @@ from ...registry import register_kernel class GmmFunction(torch.autograd.Function): - r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM).""" + """Custom autograd function for NPU Grouped Matrix Multiplication (GMM).""" @staticmethod def forward(ctx, x, weight, group_list): - r"""Performs the forward pass of Grouped Matrix Multiplication. + """Performs the forward pass of Grouped Matrix Multiplication. Args: ctx: Context object to save tensors for backward pass. @@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function): @staticmethod def backward(ctx, grad_output): - r"""Performs the backward pass of Grouped Matrix Multiplication. + """Performs the backward pass of Grouped Matrix Multiplication. Args: ctx: Context object containing saved tensors. @@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function): class HybridGmmFunction(torch.autograd.Function): - r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU.""" + """Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU.""" @staticmethod def forward(ctx, num_experts, *args): - r"""Performs the forward pass of Hybrid GMM. + """Performs the forward pass of Hybrid GMM. Args: ctx: Context object to save tensors. @@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function): @staticmethod def backward(ctx, *grad_outputs): - r"""Performs the backward pass of Hybrid GMM. + """Performs the backward pass of Hybrid GMM. Args: ctx: Context object containing saved tensors. @@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function): class NpuMoeFused: - r"""Container for NPU fused MoE forward functions.""" + """Container for NPU fused MoE forward functions.""" @staticmethod def npu_moe_experts_forward( self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor ) -> torch.Tensor: - r"""Forward pass for MoE experts using NPU fused operations. + """Forward pass for MoE experts using NPU fused operations. Args: self: The MoE layer instance. @@ -230,11 +230,11 @@ class NpuMoeFused: class Qwen3NpuMoeFused: - r"""Container for Qwen3 NPU fused MoE forward functions.""" + """Container for Qwen3 NPU fused MoE forward functions.""" @staticmethod def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor): - r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations. + """Forward pass for Qwen3 sparse MoE block using NPU fused operations. Args: self: The Qwen3 MoE block instance. @@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"): @register_kernel class NpuFusedMoEKernel(BaseKernel): - r"""NPU Fused MoE Kernel implementation.""" + """NPU Fused MoE Kernel implementation.""" _kernel_id = "npu_fused_moe" _device = DeviceType.NPU @classmethod def apply(cls, **kwargs) -> HFModel: - r"""Applies the NPU fused MoE kernel to the model. + """Applies the NPU fused MoE kernel to the model. Args: **kwargs: Keyword arguments containing the model. @@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel): if target_moe_mapping is None: return model + for module in model.modules(): class_name = module.__class__.__name__ if class_name in target_moe_mapping: diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py index e6f82051c..a45077bc0 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py @@ -38,7 +38,7 @@ except ImportError: def npu_swiglu_forward(self, hidden_state): - r"""SwiGLU forward pass for NPU. + """SwiGLU forward pass for NPU. Args: self: The MLP layer instance. @@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state): def _npu_swiglu_glm4_forward(self, hidden_states): - r"""SwiGLU forward pass for GLM4 on NPU. + """SwiGLU forward pass for GLM4 on NPU. Args: self: The GLM4 MLP layer instance. @@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states): def _npu_swiglu_gemma3ntext_forward(self, hidden_states): - r"""SwiGLU forward pass for Gemma3nText on NPU. + """SwiGLU forward pass for Gemma3nText on NPU. Args: self: The Gemma3nText MLP layer instance. @@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states): @register_kernel class NpuSwiGluKernel(BaseKernel): - r"""NPU Kernel for fused SwiGLU activation.""" + """NPU Kernel for fused SwiGLU activation.""" # just support apply to the following module layers expect_modules = frozenset( @@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel): @classmethod def apply(cls, **kwargs) -> "HFModel": - r"""Applies the NPU fused SwiGLU kernel to the model. + """Applies the NPU fused SwiGLU kernel to the model. Args: **kwargs: Keyword arguments containing the model. diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py index 6ce36bb67..35057f451 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py @@ -30,7 +30,7 @@ from ...registry import register_kernel def npu_rms_norm_forward(self, hidden_states): - r"""NPU forward implementation for RMSNorm. + """NPU forward implementation for RMSNorm. Args: self: RMSNorm module instance with `weight` and `variance_epsilon`. @@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states): @register_kernel class NpuRMSNormKernel(BaseKernel): - r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model.""" + """NPU kernel wrapper for RMSNorm that applies the replacement within a model.""" _kernel_id = "npu_fused_rmsnorm" _device = DeviceType.NPU @classmethod def apply(cls, **kwargs) -> "HFModel": - r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. + """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules. Key points: - Match modules whose class name contains "RMSNorm" (case-insensitive). @@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel): if not cls.check_deps(): raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.") + rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE) for name, module in model.named_modules(): diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py index b431b5063..2f3e290a0 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py @@ -40,7 +40,7 @@ except ImportError: def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization. + """Applies Rotary Position Embedding to the query and key tensors using NPU optimization. Args: q (Tensor): Query tensor. @@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1): - r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU. + """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU. Args: q (Tensor): Query tensor. @@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un @register_kernel class NpuRoPEKernel(BaseKernel): - r"""NPU Kernel for Rotary Position Embedding.""" + """NPU Kernel for Rotary Position Embedding.""" _kernel_id = "npu_fused_rope" _device = DeviceType.NPU @classmethod def apply(cls, **kwargs) -> "HFModel": - r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. + """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`. This function iterates through the model's modules to find attention layers, identifies the module where they are defined, and replaces the original @@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel): """ if not cls.check_deps(): raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.") + model = kwargs.get("model", None) if model is None: raise ValueError(f"HFModel instance is required for {cls.__name__}.") + _modules = set() for module in model.modules(): if "Attention" in module.__class__.__name__: @@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel): _modules.add(module_name) except Exception as e: logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}") + return model diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index f6c1984ae..16a5c8e0f 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"] class Registry: - r"""Registry for managing kernel implementations. + """Registry for managing kernel implementations. Storage structure: ``{ "kernel_id": Class }`` """ @@ -38,8 +38,8 @@ class Registry: _kernels: dict[str, type[BaseKernel]] = {} @classmethod - def register(cls, kernel_cls: type[BaseKernel]): - r"""Decorator to register a kernel class. + def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None: + """Decorator to register a kernel class. The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes. @@ -47,7 +47,7 @@ class Registry: kernel_cls (type[BaseKernel]): The kernel class to register. Returns: - type[BaseKernel]: The registered kernel class. + type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator Raises: TypeError: If the class does not inherit from :class:`BaseKernel`. @@ -55,6 +55,7 @@ class Registry: """ if not issubclass(kernel_cls, BaseKernel): raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel") + kernel_id = kernel_cls.get_kernel_id() device = kernel_cls.get_device() @@ -73,7 +74,7 @@ class Registry: @classmethod def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]: - r"""Retrieves a registered kernel implementation by its ID. + """Retrieves a registered kernel implementation by its ID. Args: kernel_id (str): The ID of the kernel to retrieve. @@ -85,7 +86,7 @@ class Registry: @classmethod def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]: - r"""Returns a dictionary of all registered kernels. + """Returns a dictionary of all registered kernels. Returns: dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes. diff --git a/src/llamafactory/v1/plugins/model_plugins/peft.py b/src/llamafactory/v1/plugins/model_plugins/peft.py index 819b06d14..a3f37482c 100644 --- a/src/llamafactory/v1/plugins/model_plugins/peft.py +++ b/src/llamafactory/v1/plugins/model_plugins/peft.py @@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin): return super().__call__(model, config) -@PeftPlugin("lora").register +@PeftPlugin("lora").register() def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel: peft_config = LoraConfig(**config) model = get_peft_model(model, peft_config) return model -@PeftPlugin("freeze").register +@PeftPlugin("freeze").register() def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel: raise NotImplementedError() diff --git a/src/llamafactory/v1/plugins/model_plugins/rendering.py b/src/llamafactory/v1/plugins/model_plugins/rendering.py new file mode 100644 index 000000000..7e9d026d2 --- /dev/null +++ b/src/llamafactory/v1/plugins/model_plugins/rendering.py @@ -0,0 +1,36 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ...utils.plugin import BasePlugin +from ...utils.types import Message, ModelInput, Processor + + +class RenderingPlugin(BasePlugin): + pass + + +@RenderingPlugin("qwen").register("render_messages") +def render_qwen_messages( + processor: Processor, + messages: list[Message], + tools: str | None = None, + is_generate: bool = False, +) -> ModelInput: + raise NotImplementedError() + + +@RenderingPlugin("qwen").register("parse_message") +def parse_qwen_message(generated_text: str) -> Message: + raise NotImplementedError() diff --git a/src/llamafactory/v1/samplers/cli_sampler.py b/src/llamafactory/v1/samplers/cli_sampler.py index c08bc838a..d102ac84b 100644 --- a/src/llamafactory/v1/samplers/cli_sampler.py +++ b/src/llamafactory/v1/samplers/cli_sampler.py @@ -12,10 +12,64 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio +import os +from collections.abc import Generator +from threading import Thread -from ..config import InputArgument, SampleBackend, get_args +from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args from ..core.base_sampler import BaseSampler -from ..core.model_loader import ModelLoader +from ..core.data_engine import DataEngine +from ..core.model_engine import ModelEngine +from ..core.utils.rendering import Renderer +from ..utils.types import HFModel, Message, Sample, TorchDataset + + +class SyncSampler(BaseSampler): + def __init__( + self, + args: SampleArguments, + model_args: ModelArguments, + model: HFModel, + renderer: Renderer, + ) -> None: + def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None: + asyncio.set_event_loop(loop) + loop.run_forever() + + super().__init__(args, model_args, model, renderer) + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) + self._thread.start() + + def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]: + """Generate tokens synchronously. + + Args: + messages: List of messages. + tools: Tools string. + + Yields: + Generated tokens. + """ + generator = super().generate(messages, tools) + while True: + try: + token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result() + yield token + except StopAsyncIteration: + break + + def batch_infer(self, dataset: TorchDataset) -> list[Sample]: + """Batch infer samples synchronously. + + Args: + dataset: Torch dataset. + + Returns: + List of samples. + """ + return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result() def run_chat(args: InputArgument = None): @@ -23,12 +77,48 @@ def run_chat(args: InputArgument = None): if sample_args.sample_backend != SampleBackend.HF: model_args.init_plugin = {"name": "init_on_meta"} - model_loader = ModelLoader(model_args) - sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor) + model_engine = ModelEngine(model_args) + sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer) if data_args.dataset is not None: - sampler.batch_infer() + dataset = DataEngine(data_args) + sampler.batch_infer(dataset) else: - sampler.generate() + if os.name != "nt": + try: + import readline # noqa: F401 + except ImportError: + print("Install `readline` for a better experience.") + + messages = [] + print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") + + while True: + try: + query = input("\nUser: ") + except UnicodeDecodeError: + print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") + continue + except Exception: + raise + + if query.strip() == "exit": + break + + if query.strip() == "clear": + messages = [] + print("History has been removed.") + continue + + messages.append({"role": "user", "content": [{"type": "text", "value": query}]}) + print("Assistant: ", end="", flush=True) + + response = "" + for new_text in sampler.generate(messages): + print(new_text, end="", flush=True) + response += new_text + + print() + messages.append(model_engine.renderer.parse_message(response)) if __name__ == "__main__": diff --git a/src/llamafactory/v1/trainers/sft_trainer.py b/src/llamafactory/v1/trainers/sft_trainer.py index c48c1092f..2cb8d3a3c 100644 --- a/src/llamafactory/v1/trainers/sft_trainer.py +++ b/src/llamafactory/v1/trainers/sft_trainer.py @@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface from ..config.arg_parser import get_args from ..core.base_trainer import BaseTrainer from ..core.data_engine import DataEngine -from ..core.model_loader import ModelLoader +from ..core.model_engine import ModelEngine class SFTTrainer(BaseTrainer): @@ -28,11 +28,11 @@ def run_sft(user_args): model_args, data_args, training_args, _ = get_args(user_args) DistributedInterface(training_args.dist_config) data_engine = DataEngine(data_args) - model_loader = ModelLoader(model_args) + model_engine = ModelEngine(model_args) trainer = SFTTrainer( args=training_args, - model=model_loader.model, - processor=model_loader.processor, + model=model_engine.model, + processor=model_engine.processor, dataset=data_engine, ) trainer.fit() diff --git a/src/llamafactory/v1/utils/constants.py b/src/llamafactory/v1/utils/constants.py index ec0d62554..9ec68b44d 100644 --- a/src/llamafactory/v1/utils/constants.py +++ b/src/llamafactory/v1/utils/constants.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +IGNORE_INDEX = -100 diff --git a/src/llamafactory/v1/utils/helper.py b/src/llamafactory/v1/utils/helper.py new file mode 100644 index 000000000..7cbd9f336 --- /dev/null +++ b/src/llamafactory/v1/utils/helper.py @@ -0,0 +1,29 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import PreTrainedTokenizer + +from .types import Processor + + +def get_tokenizer(processor: Processor) -> PreTrainedTokenizer: + """Get tokenizer from processor. + + Args: + processor: Processor. + + Returns: + Tokenizer. + """ + return processor.tokenizer if hasattr(processor, "tokenizer") else processor diff --git a/src/llamafactory/v1/utils/logging.py b/src/llamafactory/v1/utils/logging.py index 81bc53751..4d38927ff 100644 --- a/src/llamafactory/v1/utils/logging.py +++ b/src/llamafactory/v1/utils/logging.py @@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level": def _get_library_name() -> str: - return __name__.split(".")[0] + return ".".join(__name__.split(".")[:2]) # llamafactory.v1 def _get_library_root_logger() -> "_Logger": diff --git a/src/llamafactory/v1/utils/plugin.py b/src/llamafactory/v1/utils/plugin.py index 0e4bcdf8e..136896492 100644 --- a/src/llamafactory/v1/utils/plugin.py +++ b/src/llamafactory/v1/utils/plugin.py @@ -13,6 +13,7 @@ # limitations under the License. +from collections import defaultdict from collections.abc import Callable from . import logging @@ -27,7 +28,7 @@ class BasePlugin: A plugin is a callable object that can be registered and called by name. """ - _registry: dict[str, Callable] = {} + _registry: dict[str, dict[str, Callable]] = defaultdict(dict) def __init__(self, name: str | None = None): """Initialize the plugin with a name. @@ -37,8 +38,7 @@ class BasePlugin: """ self.name = name - @property - def register(self): + def register(self, method_name: str = "__call__"): """Decorator to register a function as a plugin. Example usage: @@ -46,16 +46,21 @@ class BasePlugin: @PrintPlugin("hello").register() def print_hello(): print("Hello world!") + + + @PrintPlugin("hello").register("again") + def print_hello_again(): + print("Hello world! Again.") ``` """ if self.name is None: - raise ValueError("Plugin name is not specified.") + raise ValueError("Plugin name should be specified.") - if self.name in self._registry: - logger.warning_rank0_once(f"Plugin {self.name} is already registered.") + if method_name in self._registry[self.name]: + logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.") def decorator(func: Callable) -> Callable: - self._registry[self.name] = func + self._registry[self.name][method_name] = func return func return decorator @@ -68,10 +73,23 @@ class BasePlugin: PrintPlugin("hello")() ``` """ - if self.name not in self._registry: - raise ValueError(f"Plugin {self.name} is not registered.") + if "__call__" not in self._registry[self.name]: + raise ValueError(f"Method __call__ of plugin {self.name} is not registered.") - return self._registry[self.name](*args, **kwargs) + return self._registry[self.name]["__call__"](*args, **kwargs) + + def __getattr__(self, method_name: str): + """Get the registered function with the given name. + + Example usage: + ```python + PrintPlugin("hello").again() + ``` + """ + if method_name not in self._registry[self.name]: + raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.") + + return self._registry[self.name][method_name] if __name__ == "__main__": @@ -82,8 +100,13 @@ if __name__ == "__main__": class PrintPlugin(BasePlugin): pass - @PrintPlugin("hello").register + @PrintPlugin("hello").register() def print_hello(): print("Hello world!") + @PrintPlugin("hello").register("again") + def print_hello_again(): + print("Hello world! Again.") + PrintPlugin("hello")() + PrintPlugin("hello").again() diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py index bae334f4a..cad3d57a1 100644 --- a/src/llamafactory/v1/utils/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -84,27 +84,59 @@ class DistributedConfig(TypedDict, total=False): class Content(TypedDict): - type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"] + type: Literal["text", "reasoning", "tool_call", "image_url"] + """Type of the content.""" value: str + """Value of the content.""" class Message(TypedDict): role: Literal["system", "user", "assistant", "tool"] + """Role of the message.""" content: list[Content] - loss_weight: float + """Content of the message.""" + loss_weight: NotRequired[float] + """Loss weight for this message, default to 1.0. Required in training.""" class SFTSample(TypedDict): messages: list[Message] + """Messages in the sample.""" extra_info: NotRequired[str] + """Extra information for the sample, including tools, kto_labels.""" _dataset_name: NotRequired[str] + """Dataset name for the sample.""" class DPOSample(TypedDict): chosen_messages: list[Message] + """Chosen messages in the sample.""" rejected_messages: list[Message] + """Rejected messages in the sample.""" extra_info: NotRequired[str] + """Extra information for the sample, including tools, kto_labels.""" _dataset_name: NotRequired[str] + """Dataset name for the sample.""" Sample = Union[SFTSample, DPOSample] + + +class ToolCall(TypedDict): + name: str + """Function name.""" + arguments: str + """Function arguments.""" + + +class ModelInput(TypedDict, total=False): + input_ids: list[int] + """Input ids for the model.""" + attention_mask: list[int] + """Attention mask for the model.""" + labels: list[int] + """Labels for the model.""" + loss_weights: list[float] + """Loss weight for each token, default to 1.0.""" + position_ids: NotRequired[list[int] | list[list[int]]] + """Position ids for the model (optional).""" diff --git a/tests/conftest.py b/tests/conftest.py index cd20a0d2b..835a15980 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers. """ import os -from typing import Optional +import sys import pytest import torch @@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]): item.add_marker(skip_slow) -def _get_visible_devices_env() -> Optional[str]: +def _get_visible_devices_env() -> str | None: """Return device visibility env var name.""" if CURRENT_DEVICE == "cuda": return "CUDA_VISIBLE_DEVICES" @@ -149,6 +149,14 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) - devices_str = ",".join(str(i) for i in range(required)) monkeypatch.setenv(env_key, devices_str) + + # add project root dir to path for mp run + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + if project_root not in sys.path: + sys.path.insert(0, project_root) + + os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "") + else: # non-distributed test if old_value: visible_devices = [v for v in old_value.split(",") if v != ""] diff --git a/tests/version.txt b/tests/version.txt index 08eb069fe..53686d274 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.4.105 +0.9.5.101 diff --git a/tests_v1/core/test_data_loader.py b/tests_v1/core/test_data_loader.py deleted file mode 100644 index 329098242..000000000 --- a/tests_v1/core/test_data_loader.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2025 the LlamaFactory team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Integration tests for DataLoader with different combinations of packing and dynamic batching. - -Tests the 4 scenarios: -a) non pack + non dynamic. -b) non pack + dynamic. -c) pack + non dynamic. -d) pack + dynamic. -""" - -import torch -from torch.utils.data import DataLoader as TorchDataLoader -from torch.utils.data import Dataset -from transformers import AutoTokenizer - -from llamafactory.v1.config.data_args import DataArguments -from llamafactory.v1.core.data_engine import DataEngine -from llamafactory.v1.core.trainer_utils.data_collator import ( - DefaultCollator, -) -from llamafactory.v1.core.trainer_utils.data_loader import DataLoader -from llamafactory.v1.plugins.data_plugins.template import QwenTemplate -from llamafactory.v1.utils.batching_queue import TextBatchingQueue - - -class TensorDataset(Dataset): - """Wrapper dataset that converts DataEngine samples to tensor format.""" - - def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None): - self.data_engine = data_engine - self.processor = processor - self.template = template - self.max_samples = max_samples or len(data_engine) - self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor - - def __len__(self): - return min(self.max_samples, len(self.data_engine)) - - def __getitem__(self, idx): - # Get sample from DataEngine - sample = self.data_engine[idx] - - # Extract messages from sample - # DataEngine returns samples with format like {"messages": [...], ...} - # For llamafactory/v1-sft-demo, the format should have "messages" field - messages = None - if "messages" in sample: - messages = sample["messages"] - elif "conversations" in sample: - messages = sample["conversations"] - elif "conversation" in sample: - messages = sample["conversation"] - else: - # Try to find message-like fields (skip _dataset_name) - for key, value in sample.items(): - if key.startswith("_"): - continue - if isinstance(value, list) and len(value) > 0: - # Check if it looks like a message list - if isinstance(value[0], dict) and "role" in value[0]: - messages = value - break - - if messages is None: - raise ValueError(f"Could not find messages in sample: {list(sample.keys())}") - - # Encode messages using template - encoded = self.template.encode_messages(self.tokenizer, messages) - - # Convert to tensors - return { - "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), - "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long), - "labels": torch.tensor(encoded["labels"], dtype=torch.long), - } - - -def create_real_dataset(max_samples: int = 20, batch_size: int = 4): - """Create a real dataset using DataEngine.""" - data_args = DataArguments(dataset="llamafactory/v1-sft-demo") - data_engine = DataEngine(data_args) - - # Create processor and template - processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5") - template = QwenTemplate() - - # Create tensor dataset - raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples) - - # Create torch DataLoader - torch_dataloader = TorchDataLoader( - raw_data_dataset, - batch_size=batch_size, - shuffle=False, - collate_fn=lambda x: x, - ) - - return torch_dataloader, processor, template - - -class TestDataLoaderNonPackNonDynamic: - """Test case a) non pack + non dynamic.""" - - def test_basic_functionality(self): - """Test DataLoader without packing and without dynamic batching.""" - # Create real dataset - torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) - - # Create collator (non-packing) - collator = DefaultCollator(processor=processor, template=template) - - # Create DataLoader without batching_queue (non-dynamic) - data_loader = DataLoader( - dataloader=torch_dataloader, - collate_fn=collator, - num_micro_batch=1, - batching_queue=None, - ) - - # Iterate and check results - batches = list(iter(data_loader)) - assert len(batches) > 0 - - # Check first batch - one_batch = batches[0] - micro_batches = one_batch[0] - assert "input_ids" in micro_batches - assert "attention_mask" in micro_batches - assert "labels" in micro_batches - assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1 - assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len] - - -class TestDataLoaderNonPackDynamic: - """Test case b) non pack + dynamic.""" - - def test_basic_functionality(self): - """Test DataLoader without packing but with dynamic batching.""" - # Create real dataset - torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) - collator = DefaultCollator(processor=processor, template=template) - - # Create batching queue for dynamic batching - batching_queue = TextBatchingQueue( - token_micro_bsz=120, - buffer_size=8, - ) - - data_loader = DataLoader( - dataloader=torch_dataloader, - collate_fn=collator, - num_micro_batch=4, - batching_queue=batching_queue, - ) - - # Iterate and check - batches = list(iter(data_loader)) - micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]] - assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first) - assert len(batches) > 0 diff --git a/tests_v1/core/test_model_loader.py b/tests_v1/core/test_model_loader.py index fa038229e..5f0150b63 100644 --- a/tests_v1/core/test_model_loader.py +++ b/tests_v1/core/test_model_loader.py @@ -15,18 +15,18 @@ import torch from llamafactory.v1.config.model_args import ModelArguments, PluginConfig -from llamafactory.v1.core.model_loader import ModelLoader +from llamafactory.v1.core.model_engine import ModelEngine def test_tiny_qwen(): from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5") - model_loader = ModelLoader(model_args) - assert isinstance(model_loader.processor, Qwen2TokenizerFast) - assert isinstance(model_loader.model.config, Qwen2Config) - assert isinstance(model_loader.model, Qwen2ForCausalLM) - assert model_loader.model.dtype == torch.bfloat16 + model_engine = ModelEngine(model_args) + assert isinstance(model_engine.processor, Qwen2TokenizerFast) + assert isinstance(model_engine.model_config, Qwen2Config) + assert isinstance(model_engine.model, Qwen2ForCausalLM) + assert model_engine.model.dtype == torch.bfloat16 def test_tiny_qwen_with_kernel_plugin(): @@ -37,13 +37,14 @@ def test_tiny_qwen_with_kernel_plugin(): model_args = ModelArguments( model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto") ) - model_loader = ModelLoader(model_args) + model_engine = ModelEngine(model_args) # test enable apply kernel plugin if hasattr(torch, "npu"): - assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__ + assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__ else: - assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__ - assert isinstance(model_loader.model, Qwen2ForCausalLM) + assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__ + + assert isinstance(model_engine.model, Qwen2ForCausalLM) if __name__ == "__main__": diff --git a/tests_v1/core/utils/test_data_loader.py b/tests_v1/core/utils/test_data_loader.py new file mode 100644 index 000000000..cbddf0887 --- /dev/null +++ b/tests_v1/core/utils/test_data_loader.py @@ -0,0 +1,171 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for DataLoader with different combinations of packing and dynamic batching. + +Tests the 4 scenarios: +a) non pack + non dynamic. +b) non pack + dynamic. +c) pack + non dynamic. +d) pack + dynamic. +""" + +# import torch +# from torch.utils.data import DataLoader as TorchDataLoader +# from torch.utils.data import Dataset +# from transformers import AutoTokenizer + +# from llamafactory.v1.config.data_args import DataArguments +# from llamafactory.v1.core.data_engine import DataEngine +# from llamafactory.v1.core.utils.data_collator import DefaultCollator +# from llamafactory.v1.core.utils.data_loader import DataLoader +# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate +# from llamafactory.v1.utils.batching_queue import TextBatchingQueue + + +# class TensorDataset(Dataset): +# """Wrapper dataset that converts DataEngine samples to tensor format.""" + +# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None): +# self.data_engine = data_engine +# self.processor = processor +# self.template = template +# self.max_samples = max_samples or len(data_engine) +# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor + +# def __len__(self): +# return min(self.max_samples, len(self.data_engine)) + +# def __getitem__(self, idx): +# # Get sample from DataEngine +# sample = self.data_engine[idx] + +# # Extract messages from sample +# # DataEngine returns samples with format like {"messages": [...], ...} +# # For llamafactory/v1-sft-demo, the format should have "messages" field +# messages = None +# if "messages" in sample: +# messages = sample["messages"] +# elif "conversations" in sample: +# messages = sample["conversations"] +# elif "conversation" in sample: +# messages = sample["conversation"] +# else: +# # Try to find message-like fields (skip _dataset_name) +# for key, value in sample.items(): +# if key.startswith("_"): +# continue +# if isinstance(value, list) and len(value) > 0: +# # Check if it looks like a message list +# if isinstance(value[0], dict) and "role" in value[0]: +# messages = value +# break + +# if messages is None: +# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}") + +# # Encode messages using template +# encoded = self.template.encode_messages(self.tokenizer, messages) + +# # Convert to tensors +# return { +# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long), +# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long), +# "labels": torch.tensor(encoded["labels"], dtype=torch.long), +# } + + +# def create_real_dataset(max_samples: int = 20, batch_size: int = 4): +# """Create a real dataset using DataEngine.""" +# data_args = DataArguments(dataset="llamafactory/v1-sft-demo") +# data_engine = DataEngine(data_args) + +# # Create processor and template +# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5") +# template = QwenTemplate() + +# # Create tensor dataset +# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples) + +# # Create torch DataLoader +# torch_dataloader = TorchDataLoader( +# raw_data_dataset, +# batch_size=batch_size, +# shuffle=False, +# collate_fn=lambda x: x, +# ) + +# return torch_dataloader, processor, template + + +# class TestDataLoaderNonPackNonDynamic: +# """Test case a) non pack + non dynamic.""" + +# def test_basic_functionality(self): +# """Test DataLoader without packing and without dynamic batching.""" +# # Create real dataset +# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) + +# # Create collator (non-packing) +# collator = DefaultCollator(processor=processor, template=template) + +# # Create DataLoader without batching_queue (non-dynamic) +# data_loader = DataLoader( +# dataloader=torch_dataloader, +# collate_fn=collator, +# num_micro_batch=1, +# batching_queue=None, +# ) + +# # Iterate and check results +# batches = list(iter(data_loader)) +# assert len(batches) > 0 + +# # Check first batch +# one_batch = batches[0] +# micro_batches = one_batch[0] +# assert "input_ids" in micro_batches +# assert "attention_mask" in micro_batches +# assert "labels" in micro_batches +# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1 +# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len] + + +# class TestDataLoaderNonPackDynamic: +# """Test case b) non pack + dynamic.""" + +# def test_basic_functionality(self): +# """Test DataLoader without packing but with dynamic batching.""" +# # Create real dataset +# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8) +# collator = DefaultCollator(processor=processor, template=template) + +# # Create batching queue for dynamic batching +# batching_queue = TextBatchingQueue( +# token_micro_bsz=120, +# buffer_size=8, +# ) + +# data_loader = DataLoader( +# dataloader=torch_dataloader, +# collate_fn=collator, +# num_micro_batch=4, +# batching_queue=batching_queue, +# ) + +# # Iterate and check +# batches = list(iter(data_loader)) +# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]] +# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first) +# assert len(batches) > 0 diff --git a/tests_v1/core/utils/test_rendering.py b/tests_v1/core/utils/test_rendering.py new file mode 100644 index 000000000..bca3d00c9 --- /dev/null +++ b/tests_v1/core/utils/test_rendering.py @@ -0,0 +1,65 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from transformers import AutoTokenizer + +from llamafactory.v1.core.utils.rendering import Renderer +from llamafactory.v1.utils.types import Processor + + +HF_MESSAGES = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is LLM?"}, + {"role": "assistant", "content": "LLM stands for Large Language Model."}, +] +V1_MESSAGES = [ + {"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]}, + {"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]}, + {"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]}, +] + + +def test_chatml_rendering(): + tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3") + renderer = Renderer(template="chatml", processor=tokenizer) + + hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True) + v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True) + assert v1_inputs["input_ids"] == hf_inputs + assert v1_inputs["attention_mask"] == [1] * len(hf_inputs) + assert v1_inputs["labels"] == [-100] * len(hf_inputs) + assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs) + + hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False) + hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False) + v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False) + assert v1_inputs_full["input_ids"] == hf_inputs_full + assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full) + assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :] + assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * ( + len(hf_inputs_full) - len(hf_inputs_part) + ) + + +def test_chatml_parse(): + tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3") + renderer = Renderer(template="chatml", processor=tokenizer) + generated_text = "LLM stands for Large Language Model." + parsed_message = renderer.parse_message(generated_text) + assert parsed_message == V1_MESSAGES[-1] + + +if __name__ == "__main__": + test_chatml_rendering() + test_chatml_parse() diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py index 0ad40aa19..2a12cf55d 100644 --- a/tests_v1/plugins/data_plugins/test_converter.py +++ b/tests_v1/plugins/data_plugins/test_converter.py @@ -54,7 +54,7 @@ def test_sharegpt_converter(): "conversations": [ {"from": "system", "value": "System"}, {"from": "human", "value": "User"}, - {"from": "function_call", "value": "Tool"}, + {"from": "function_call", "value": "1"}, {"from": "observation", "value": "Observation"}, {"from": "gpt", "value": "Assistant"}, ] @@ -63,7 +63,7 @@ def test_sharegpt_converter(): "messages": [ {"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"}, {"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"}, - {"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"}, + {"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"}, {"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"}, {"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"}, ] diff --git a/tests_v1/plugins/model_plugins/test_init_plugin.py b/tests_v1/plugins/model_plugins/test_init_plugin.py index 80e9d178b..d13294c8e 100644 --- a/tests_v1/plugins/model_plugins/test_init_plugin.py +++ b/tests_v1/plugins/model_plugins/test_init_plugin.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from llamafactory.v1.accelerator.interface import DistributedInterface from llamafactory.v1.config.arg_parser import get_args -from llamafactory.v1.core.model_loader import ModelLoader +from llamafactory.v1.core.model_engine import ModelEngine def test_init_on_meta(): @@ -26,11 +25,10 @@ def test_init_on_meta(): init_config={"name": "init_on_meta"}, ) ) - model_loader = ModelLoader(model_args=model_args) - assert model_loader.model.device.type == "meta" + model_engine = ModelEngine(model_args=model_args) + assert model_engine.model.device.type == "meta" -@pytest.mark.runs_on(["cuda", "npu"]) def test_init_on_rank0(): _, model_args, *_ = get_args( dict( @@ -38,11 +36,11 @@ def test_init_on_rank0(): init_config={"name": "init_on_rank0"}, ) ) - model_loader = ModelLoader(model_args=model_args) + model_engine = ModelEngine(model_args=model_args) if DistributedInterface().get_rank() == 0: - assert model_loader.model.device.type == "cpu" + assert model_engine.model.device.type == "cpu" else: - assert model_loader.model.device.type == "meta" + assert model_engine.model.device.type == "meta" def test_init_on_default(): @@ -52,5 +50,5 @@ def test_init_on_default(): init_config={"name": "init_on_default"}, ) ) - model_loader = ModelLoader(model_args=model_args) - assert model_loader.model.device.type == DistributedInterface().current_accelerator.type + model_engine = ModelEngine(model_args=model_args) + assert model_engine.model.device == DistributedInterface().current_device diff --git a/tests_v1/sampler/test_cli_sampler.py b/tests_v1/sampler/test_cli_sampler.py new file mode 100644 index 000000000..cd20432a3 --- /dev/null +++ b/tests_v1/sampler/test_cli_sampler.py @@ -0,0 +1,41 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from llamafactory.v1.config import ModelArguments, SampleArguments +from llamafactory.v1.core.model_engine import ModelEngine +from llamafactory.v1.samplers.cli_sampler import SyncSampler + + +@pytest.mark.runs_on(["cuda", "npu"]) +def test_sync_sampler(): + model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507") + sample_args = SampleArguments() + model_engine = ModelEngine(model_args) + sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer) + messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}] + response = "" + for new_text in sampler.generate(messages): + response += new_text + + print(response) + assert model_engine.renderer.parse_message(response) == { + "role": "assistant", + "content": [{"type": "text", "value": "This is a test."}], + } + + +if __name__ == "__main__": + test_sync_sampler()