diff --git a/.github/workflows/tests_cuda.yml b/.github/workflows/tests_cuda.yml
index 1c51e1c7a..53548f509 100644
--- a/.github/workflows/tests_cuda.yml
+++ b/.github/workflows/tests_cuda.yml
@@ -55,12 +55,12 @@ jobs:
uv pip install -e .
uv pip install -r requirements/dev.txt
- - name: Cache HuggingFace models
+ - name: Cache files
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
- key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
+ key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality
run: |
diff --git a/pyproject.toml b/pyproject.toml
index 357df9314..0faa5d9e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,7 +73,7 @@ dependencies = [
# api
"uvicorn",
"fastapi",
- "sse-starlette"
+ "sse-starlette",
]
[project.scripts]
diff --git a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py
index 04ded900c..c9cee520c 100644
--- a/src/llamafactory/v1/accelerator/helper.py
+++ b/src/llamafactory/v1/accelerator/helper.py
@@ -119,9 +119,19 @@ def synchronize() -> None:
@requires_accelerator
-def set_device() -> None:
- """Set current accelerator."""
- torch.accelerator.set_device_index(get_local_rank())
+def set_device_index() -> None:
+ """Set current accelerator index to local rank."""
+ if get_current_accelerator().type != DeviceType.CPU:
+ torch.accelerator.set_device_index(get_local_rank())
+
+
+@requires_accelerator
+def get_current_device() -> torch.device:
+ """Get current accelerator device."""
+ if get_current_accelerator().type == DeviceType.CPU:
+ return torch.device(DeviceType.CPU.value)
+ else:
+ return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
def is_torch_cuda_available():
diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py
index c7cdf4f5f..2837cbc74 100644
--- a/src/llamafactory/v1/accelerator/interface.py
+++ b/src/llamafactory/v1/accelerator/interface.py
@@ -123,12 +123,13 @@ class DistributedInterface:
if self._initialized:
return
+ helper.set_device_index()
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
- self.current_accelerator = helper.get_current_accelerator()
+ self.current_device = helper.get_current_device()
self.device_count = helper.get_device_count()
if config is None:
@@ -144,15 +145,14 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
- helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
- device_type=self.current_accelerator.type,
+ device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
- device_type=self.current_accelerator.type,
+ device_type=self.current_device.type,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
@@ -161,12 +161,12 @@ class DistributedInterface:
self.data_device_mesh = None
self._initialized = True
- logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")
+ logger.info_rank0(f"DistributedInterface initialized: {self}.")
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
- f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
+ f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
@@ -251,4 +251,7 @@ class DistributedInterface:
if __name__ == "__main__":
- print(DistributedInterface(DistributedStrategy()))
+ """
+ python -m llamafactory.v1.accelerator.interface
+ """
+ print(DistributedInterface())
diff --git a/src/llamafactory/v1/config/arg_utils.py b/src/llamafactory/v1/config/arg_utils.py
index 5335cdbb5..5673bd687 100644
--- a/src/llamafactory/v1/config/arg_utils.py
+++ b/src/llamafactory/v1/config/arg_utils.py
@@ -17,7 +17,7 @@
import json
-from enum import Enum, unique
+from enum import StrEnum, unique
class PluginConfig(dict):
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
@unique
-class ModelClass(str, Enum):
+class ModelClass(StrEnum):
"""Auto class for model config."""
LLM = "llm"
@@ -45,7 +45,7 @@ class ModelClass(str, Enum):
@unique
-class SampleBackend(str, Enum):
+class SampleBackend(StrEnum):
HF = "hf"
VLLM = "vllm"
diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py
index b79ee86de..d017cd149 100644
--- a/src/llamafactory/v1/config/model_args.py
+++ b/src/llamafactory/v1/config/model_args.py
@@ -21,8 +21,13 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
model: str = field(
+ default="Qwen/Qwen3-4B-Instruct-2507",
metadata={"help": "Path to the model or model identifier from Hugging Face."},
)
+ template: str = field(
+ default="chatml",
+ metadata={"help": "Template for the model."},
+ )
trust_remote_code: bool = field(
default=False,
metadata={"help": "Trust remote code from Hugging Face."},
diff --git a/src/llamafactory/v1/core/base_sampler.py b/src/llamafactory/v1/core/base_sampler.py
index fbd36ab69..aaf376a9f 100644
--- a/src/llamafactory/v1/core/base_sampler.py
+++ b/src/llamafactory/v1/core/base_sampler.py
@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
+import os
from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator, Generator
+from threading import Thread
+import torch
+from transformers import TextIteratorStreamer
+
+from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
-from ..utils.types import HFModel, Processor, TorchDataset
+from ..utils.helper import get_tokenizer
+from ..utils.types import HFModel, Message, Sample, TorchDataset
+from .utils.rendering import Renderer
class BaseEngine(ABC):
@@ -24,8 +34,8 @@ class BaseEngine(ABC):
self,
args: SampleArguments,
model_args: ModelArguments,
- model: HFModel = None,
- processor: Processor = None,
+ model: HFModel,
+ renderer: Renderer,
) -> None:
"""Initialize the engine.
@@ -33,17 +43,34 @@ class BaseEngine(ABC):
args: Sample arguments.
model_args: Model arguments.
model: Model.
- processor: Processor.
+ renderer: Renderer.
"""
...
@abstractmethod
- async def generate(self, messages):
- pass
+ async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
+ """Generate tokens asynchronously.
+
+ Args:
+ messages: List of messages.
+ tools: Tools string.
+
+ Yields:
+ Generated tokens.
+ """
+ ...
@abstractmethod
- async def batch_infer(self, data: TorchDataset) -> None:
- pass
+ async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
+ """Batch infer samples.
+
+ Args:
+ dataset: Torch dataset.
+
+ Returns:
+ List of samples.
+ """
+ ...
class HuggingFaceEngine(BaseEngine):
@@ -52,26 +79,103 @@ class HuggingFaceEngine(BaseEngine):
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
- processor: Processor,
+ renderer: Renderer,
) -> None:
self.args = args
+ self.model_args = model_args
+ self.model = model
+ self.renderer = renderer
+ self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
+
+ @torch.inference_mode()
+ def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
+ model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
+ streamer = TextIteratorStreamer(
+ tokenizer=get_tokenizer(self.renderer.processor),
+ skip_prompt=True,
+ skip_special_tokens=True, # TODO: configurable
+ )
+ device = DistributedInterface().current_device
+ kwargs = {
+ "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
+ "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
+ "max_new_tokens": self.args.max_new_tokens,
+ "streamer": streamer,
+ }
+ thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
+ thread.start()
+
+ def stream():
+ try:
+ return streamer.__next__()
+ except StopIteration:
+ raise StopAsyncIteration()
+
+ return stream
+
+ async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
+ async with self.semaphore:
+ response = self.get_response(messages, tools)
+ while True:
+ try:
+ yield await asyncio.to_thread(response)
+ except StopAsyncIteration:
+ break
+
+ async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
+ """Batch infer samples.
+
+ Args:
+ dataset: Torch dataset.
+
+ Returns:
+ List of samples.
+ """
+ raise NotImplementedError("Batch infer is not implemented.")
class BaseSampler:
+ """Base sampler.
+
+ Args:
+ args: Sample arguments.
+ model_args: Model arguments.
+ model: Model.
+ renderer: Renderer.
+ """
+
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
- processor: Processor,
+ renderer: Renderer,
) -> None:
if args.sample_backend == SampleBackend.HF:
- self.engine = HuggingFaceEngine(args, model_args, model, processor)
+ self.engine = HuggingFaceEngine(args, model_args, model, renderer)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
- async def generate(self, messages):
- return await self.engine.generate(messages)
+ async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
+ """Generate tokens asynchronously.
- async def batch_infer(self, data: TorchDataset) -> None:
- return await self.engine.batch_infer(data)
+ Args:
+ messages: List of messages.
+ tools: Tools string.
+
+ Yields:
+ Generated tokens.
+ """
+ async for token in self.engine.generate(messages, tools):
+ yield token
+
+ async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
+ """Batch infer samples.
+
+ Args:
+ dataset: Torch dataset.
+
+ Returns:
+ List of samples.
+ """
+ return await self.engine.batch_infer(dataset)
diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py
index f0ebb00af..18f14f1b6 100644
--- a/src/llamafactory/v1/core/data_engine.py
+++ b/src/llamafactory/v1/core/data_engine.py
@@ -14,15 +14,23 @@
"""The definition of data engine.
-Init Data engine:
+How to use:
+data_engine = DataEngine(data_args)
+data_engine[i]: Get the sample via index.
+
+Init workflow:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
-Get Data Sample:
+Get data sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
+
+Note:
+1. The data engine is equivalent to the torch dataset.
+2. The data engine is agnostic to the model used.
"""
import os
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
- if size or weight: # data index plugin
- from ..plugins.data_plugins.loader import DataIndexPlugin
+ if size or weight:
+ from ..plugins.data_plugins.loader import adjust_data_index
- data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
+ data_index = adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin
- from ..plugins.data_plugins.loader import DataSelectorPlugin
+ from ..plugins.data_plugins.loader import select_data_sample
- selected_index = DataSelectorPlugin().select(self.data_index, index)
+ selected_index = select_data_sample(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
diff --git a/src/llamafactory/v1/core/model_loader.py b/src/llamafactory/v1/core/model_engine.py
similarity index 85%
rename from src/llamafactory/v1/core/model_loader.py
rename to src/llamafactory/v1/core/model_engine.py
index 069292274..f69b4c203 100644
--- a/src/llamafactory/v1/core/model_loader.py
+++ b/src/llamafactory/v1/core/model_engine.py
@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-"""The definition of model loader.
+"""The definition of model engine.
How to use:
-model_loader = ModelLoader(model_args, is_trainable=True)
-model_loader.processor: Get the tokenizer or multi-modal processor.
-model_loader.model_config: Get the model configuration.
-model_loader.model: Get the HF model.
+model_engine = ModelEngine(model_args, is_train=True)
+model_engine.processor: Get the tokenizer or multi-modal processor.
+model_engine.renderer: Get the renderer.
+model_engine.model_config: Get the model configuration.
+model_engine.model: Get the HF model.
-Init Workflow:
+Init workflow:
1. Init processor.
+2. Init render.
2. Init model config.
3. Init model.
4. Init adapter.
@@ -36,17 +38,18 @@ from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
+from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
-class ModelLoader:
- """Model loader.
+class ModelEngine:
+ """Model engine.
Args:
model_args: Model arguments.
- is_trainable: Whether to train the model.
+ is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -56,6 +59,8 @@ class ModelLoader:
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
+ self.renderer = Renderer(self.args.template, self.processor)
+ """Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
@@ -107,7 +112,7 @@ class ModelLoader:
init_device = InitPlugin(self.args.init_config.name)()
else:
- init_device = DistributedInterface().current_accelerator
+ init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META:
with init_empty_weights():
@@ -144,12 +149,12 @@ class ModelLoader:
if __name__ == "__main__":
"""
- python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
+ python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
- model_loader = ModelLoader(model_args=model_args)
- print(model_loader.processor)
- print(model_loader.model_config)
- print(model_loader.model)
+ model_engine = ModelEngine(model_args=model_args)
+ print(model_engine.processor)
+ print(model_engine.model_config)
+ print(model_engine.model)
diff --git a/src/llamafactory/v1/core/trainer_utils/__init__.py b/src/llamafactory/v1/core/utils/__init__.py
similarity index 100%
rename from src/llamafactory/v1/core/trainer_utils/__init__.py
rename to src/llamafactory/v1/core/utils/__init__.py
diff --git a/src/llamafactory/v1/core/trainer_utils/callback.py b/src/llamafactory/v1/core/utils/callback.py
similarity index 100%
rename from src/llamafactory/v1/core/trainer_utils/callback.py
rename to src/llamafactory/v1/core/utils/callback.py
diff --git a/src/llamafactory/v1/core/trainer_utils/data_collator.py b/src/llamafactory/v1/core/utils/data_collator.py
similarity index 100%
rename from src/llamafactory/v1/core/trainer_utils/data_collator.py
rename to src/llamafactory/v1/core/utils/data_collator.py
diff --git a/src/llamafactory/v1/core/trainer_utils/data_loader.py b/src/llamafactory/v1/core/utils/data_loader.py
similarity index 100%
rename from src/llamafactory/v1/core/trainer_utils/data_loader.py
rename to src/llamafactory/v1/core/utils/data_loader.py
diff --git a/src/llamafactory/v1/core/trainer_utils/lr_scheduler.py b/src/llamafactory/v1/core/utils/lr_scheduler.py
similarity index 100%
rename from src/llamafactory/v1/core/trainer_utils/lr_scheduler.py
rename to src/llamafactory/v1/core/utils/lr_scheduler.py
diff --git a/src/llamafactory/v1/core/utils/rendering.py b/src/llamafactory/v1/core/utils/rendering.py
new file mode 100644
index 000000000..20460c2d5
--- /dev/null
+++ b/src/llamafactory/v1/core/utils/rendering.py
@@ -0,0 +1,239 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import re
+
+from ...utils.constants import IGNORE_INDEX
+from ...utils.helper import get_tokenizer
+from ...utils.types import Message, ModelInput, Processor
+
+
+def _update_model_input(
+ processor: Processor,
+ input_ids: list[int],
+ labels: list[int],
+ loss_weights: list[int],
+ temp_str: str,
+ temp_weight: float,
+) -> str:
+ """Update model input with temporary string."""
+ if not temp_str:
+ return ""
+
+ tokenizer = get_tokenizer(processor)
+ temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
+ input_ids.extend(temp_ids)
+ loss_weights.extend([temp_weight] * len(temp_ids))
+ if temp_weight > 1e-6:
+ labels.extend(temp_ids)
+ else:
+ labels.extend([IGNORE_INDEX] * len(temp_ids))
+
+ return ""
+
+
+def render_chatml_messages(
+ processor: Processor,
+ messages: list[Message],
+ tools: str | None = None,
+ is_generate: bool = False,
+) -> ModelInput:
+ """Apply chatml template to messages and convert them to model input.
+
+ See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
+ """
+ input_ids, labels, loss_weights = [], [], []
+ temp_str, temp_weight = "", 0.0
+ if tools:
+ temp_str += "<|im_start|>system\n"
+ if messages[0]["role"] == "system":
+ for content in messages[0]["content"]:
+ if content["type"] == "text":
+ temp_str += content["value"]
+ else:
+ raise ValueError(f"Unsupported content type: {content['type']}")
+
+ temp_str += "\n\n"
+ temp_weight = messages[0].get("loss_weight", 0.0)
+
+ temp_str += (
+ "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+ "You are provided with function signatures within XML tags:\n"
+ )
+ try:
+ tools = json.loads(tools)
+ except json.JSONDecodeError:
+ raise ValueError(f"Invalid tools format: {str(tools)}.")
+
+ if not isinstance(tools, list):
+ tools = [tools]
+
+ for tool in tools:
+ temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
+
+ temp_str += (
+ "\n\n\nFor each function call, return a json object with function name "
+ 'and arguments within XML tags:\n\n{"name": '
+ ', "arguments": }\n<|im_end|>\n'
+ )
+ elif messages[0]["role"] == "system":
+ temp_str += "<|im_start|>system\n"
+ for content in messages[0]["content"]:
+ if content["type"] == "text":
+ temp_str += content["value"]
+ else:
+ raise ValueError(f"Unsupported content type: {content['type']}")
+
+ temp_str += "<|im_end|>\n"
+ temp_weight = messages[0].get("loss_weight", 0.0)
+
+ temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
+
+ for turn_idx, message in enumerate(messages):
+ if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
+ temp_str += "<|im_start|>" + message["role"] + "\n"
+ for content in message["content"]:
+ if content["type"] == "text":
+ temp_str += content["value"]
+ else:
+ raise ValueError(f"Unsupported content type: {content['type']}")
+
+ temp_str += "<|im_end|>\n"
+ temp_weight = message.get("loss_weight", 0.0)
+ elif message["role"] == "assistant":
+ temp_str += "<|im_start|>" + message["role"] + "\n"
+ for val_idx, content in enumerate(message["content"]):
+ if content["type"] == "text":
+ temp_str += content["value"]
+ elif content["type"] == "reasoning":
+ temp_str += "\n" + content["value"] + "\n\n\n" # avoid using special tokens
+ elif content["type"] == "tool_call":
+ if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
+ temp_str += "\n"
+
+ try:
+ tool_call = json.loads(content["value"])
+ except json.JSONDecodeError:
+ raise ValueError(f"Invalid tool call format: {content['value']}.")
+ temp_str += (
+ '\n{"name": "'
+ + tool_call["name"]
+ + '", "arguments": '
+ + json.dumps(tool_call["arguments"], ensure_ascii=False)
+ + "}\n"
+ )
+
+ else:
+ raise ValueError(f"Unsupported content type: {content['type']}")
+
+ temp_str += "<|im_end|>\n"
+ temp_weight = message.get("loss_weight", 1.0)
+ elif message["role"] == "tool":
+ if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
+ temp_str += "<|im_start|>user"
+
+ temp_str += "\n\n"
+ for content in message["content"]:
+ if content["type"] == "text":
+ temp_str += content["value"]
+ else:
+ raise ValueError(f"Unsupported content type: {content['type']}")
+
+ temp_str += "\n"
+ if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
+ temp_str += "<|im_end|>\n"
+
+ temp_weight = message.get("loss_weight", 0.0)
+
+ temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
+
+ if is_generate:
+ temp_str += "<|im_start|>assistant\n"
+ temp_weight = 0.0
+
+ temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
+
+ attention_mask = [1] * len(input_ids)
+ return ModelInput(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ labels=labels,
+ loss_weights=loss_weights,
+ )
+
+
+def parse_chatml_message(generated_text: str) -> Message:
+ """Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
+
+ Args:
+ generated_text (str): The generated text in ChatML format.
+
+ Returns:
+ Message: The parsed message.
+ """
+ pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*\1>\s*", re.DOTALL)
+ content = []
+ last_end = 0
+ for match in pattern.finditer(generated_text):
+ start, end = match.span()
+ if start > last_end:
+ text = generated_text[last_end:start].strip()
+ if text:
+ content.append({"type": "text", "value": text})
+
+ tag_type = match.group(1)
+ tag_value = match.group(2).strip()
+ if tag_type == "thinking":
+ content.append({"type": "reasoning", "value": tag_value.strip()})
+ elif tag_type == "tool_call":
+ try:
+ json.loads(tag_value.strip())
+ except json.JSONDecodeError:
+ raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
+
+ content.append({"type": "tool_call", "value": tag_value.strip()})
+
+ last_end = end
+
+ if last_end < len(generated_text):
+ text = generated_text[last_end:].strip()
+ if text:
+ content.append({"type": "text", "value": text})
+
+ return Message(role="assistant", content=content)
+
+
+class Renderer:
+ def __init__(self, template: str, processor: Processor):
+ self.template = template
+ self.processor = processor
+
+ def render_messages(
+ self, messages: list[Message], tools: str | None = None, is_generate: bool = False
+ ) -> ModelInput:
+ if self.template == "chatml":
+ return render_chatml_messages(self.processor, messages, tools, is_generate)
+ else:
+ from ...plugins.model_plugins.rendering import RenderingPlugin
+
+ return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
+
+ def parse_message(self, generated_text: str) -> Message:
+ if self.template == "chatml":
+ return parse_chatml_message(generated_text)
+ else:
+ from ...plugins.model_plugins.rendering import RenderingPlugin
+
+ return RenderingPlugin(self.template).parse_message(generated_text)
diff --git a/src/llamafactory/v1/launcher.py b/src/llamafactory/v1/launcher.py
index 3d18286bc..9f04ab398 100644
--- a/src/llamafactory/v1/launcher.py
+++ b/src/llamafactory/v1/launcher.py
@@ -49,6 +49,11 @@ def launch():
run_sft()
+ elif command == "chat":
+ from .samplers.cli_sampler import run_chat
+
+ run_chat()
+
elif command == "env":
print_env()
diff --git a/src/llamafactory/v1/plugins/data_plugins/converter.py b/src/llamafactory/v1/plugins/data_plugins/converter.py
index 07aae2dfe..ca7b4fa98 100644
--- a/src/llamafactory/v1/plugins/data_plugins/converter.py
+++ b/src/llamafactory/v1/plugins/data_plugins/converter.py
@@ -13,11 +13,12 @@
# limitations under the License.
+import json
from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging
from ...utils.plugin import BasePlugin
-from ...utils.types import DPOSample, Sample, SFTSample
+from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__)
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample)
-@DataConverterPlugin("alpaca").register
+@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
-@DataConverterPlugin("sharegpt").register
+@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"function_call": "assistant",
}
messages = []
- tools = raw_sample.get("tools", "")
+ tools = raw_sample.get("tools")
+ if tools:
+ try:
+ tools: list[dict[str, Any]] = json.loads(tools)
+ except json.JSONDecodeError:
+ logger.warning_rank0(f"Invalid tools format: {str(tools)}")
+ tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
+ try:
+ tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
+ except json.JSONDecodeError:
+ logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
+ continue
+
+ if not isinstance(tool_calls, list):
+ tool_calls = [tool_calls]
+
messages.append(
{
"role": "assistant",
- "content": [{"type": "tool_calls", "value": message["value"]}],
+ "content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0,
}
)
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
)
if tools:
- if messages and messages[0]["role"] == "system":
- messages[0]["content"].append({"type": "tools", "value": tools})
- else:
- messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
-
- return {"messages": messages}
+ return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
+ else:
+ return {"messages": messages}
-@DataConverterPlugin("pair").register
+@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample.
diff --git a/src/llamafactory/v1/plugins/data_plugins/loader.py b/src/llamafactory/v1/plugins/data_plugins/loader.py
index 9200ec34a..eb24ca614 100644
--- a/src/llamafactory/v1/plugins/data_plugins/loader.py
+++ b/src/llamafactory/v1/plugins/data_plugins/loader.py
@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
raise ValueError(f"Unknown dataset filetype: {filetype}.")
-@DataLoaderPlugin("local").register
+@DataLoaderPlugin("local").register()
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0])
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
return dataset
-class DataIndexPlugin(BasePlugin):
- """Plugin for adjusting dataset index."""
+def adjust_data_index(
+ data_index: list[tuple[str, int]], size: int | None, weight: float | None
+) -> list[tuple[str, int]]:
+ """Adjust dataset index by size and weight.
- def adjust_data_index(
- self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
- ) -> list[tuple[str, int]]:
- """Adjust dataset index by size and weight.
+ Args:
+ data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
+ size (Optional[int]): Desired dataset size.
+ weight (Optional[float]): Desired dataset weight.
- Args:
- data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
- size (Optional[int]): Desired dataset size.
- weight (Optional[float]): Desired dataset weight.
+ Returns:
+ list[tuple[str, int]]: Adjusted dataset index.
+ """
+ if size is not None:
+ data_index = random.choices(data_index, k=size)
- Returns:
- list[tuple[str, int]]: Adjusted dataset index.
- """
- if size is not None:
- data_index = random.choices(data_index, k=size)
+ if weight is not None:
+ data_index = random.choices(data_index, k=int(len(data_index) * weight))
- if weight is not None:
- data_index = random.choices(data_index, k=int(len(data_index) * weight))
-
- return data_index
+ return data_index
-class DataSelectorPlugin(BasePlugin):
- """Plugin for selecting dataset samples."""
+def select_data_sample(
+ data_index: list[tuple[str, int]], index: slice | list[int] | Any
+) -> tuple[str, int] | list[tuple[str, int]]:
+ """Select dataset samples.
- def select(
- self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
- ) -> tuple[str, int] | list[tuple[str, int]]:
- """Select dataset samples.
+ Args:
+ data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
+ index (Union[slice, list[int], Any]): Index of dataset samples.
- Args:
- data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
- index (Union[slice, list[int], Any]): Index of dataset samples.
-
- Returns:
- Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
- """
- if isinstance(index, slice):
- return [data_index[i] for i in range(*index.indices(len(data_index)))]
- elif isinstance(index, list):
- return [data_index[i] for i in index]
- else:
- raise ValueError(f"Invalid index type {type(index)}.")
+ Returns:
+ Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
+ """
+ if isinstance(index, slice):
+ return [data_index[i] for i in range(*index.indices(len(data_index)))]
+ elif isinstance(index, list):
+ return [data_index[i] for i in index]
+ else:
+ raise ValueError(f"Invalid index type {type(index)}.")
diff --git a/src/llamafactory/v1/plugins/data_plugins/template.py b/src/llamafactory/v1/plugins/data_plugins/template.py
deleted file mode 100644
index 96159142e..000000000
--- a/src/llamafactory/v1/plugins/data_plugins/template.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from dataclasses import dataclass
-
-
-@dataclass
-class Template:
- user_template: str
- assistant_template: str
- system_template: str
-
- def render_message(self, message: dict[str, str]) -> str:
- return self.user_template.format(**message)
-
-
-@dataclass
-class QwenTemplate:
- message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
- thinking_template: str = "\n{content}\n\n\n"
-
- def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
- if isinstance(content_data, str):
- return content_data.strip()
-
- if isinstance(content_data, list):
- parts = []
- for item in content_data:
- if item.get("type") == "text":
- parts.append(item.get("value", ""))
- elif item.get("type") == "image_url":
- pass
- return "\n".join(parts).strip()
-
- return ""
-
- def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
- role = message["role"]
- content = self._extract_content(message.get("content", ""))
-
- if role == "assistant":
- reasoning_content = message.get("reasoning_content", "")
- if reasoning_content:
- reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
- return self.message_template.format(role="assistant", content=reasoning_content + content)
- else:
- return self.message_template.format(role=role, content=content)
-
- def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
- """Encode one message."""
- input_ids, attention_mask, labels = [], [], []
- for message in messages:
- content_str = self.render_message(message)
- content_ids = tokenizer.encode(content_str, add_special_tokens=False)
- input_ids += content_ids
- attention_mask += [1] * len(content_ids)
-
- if hasattr(message, "loss_weight"):
- loss_weight = message["loss_weight"]
- else:
- loss_weight = 1 if message["role"] == "assistant" else 0
- if loss_weight == 1:
- labels += content_ids
- else:
- labels += [-100] * len(content_ids)
- model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
- model_inputs.update({"position_ids": list(range(len(input_ids)))})
- model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
- return model_inputs
-
-
-if __name__ == "__main__":
-
- def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
- out = []
- for m in messages:
- role = m["role"]
- content = template._extract_content(m.get("content", ""))
- if role == "assistant":
- reasoning = (m.get("reasoning_content") or "").strip()
- if reasoning:
- content = template.thinking_template.format(content=reasoning) + content
- out.append({"role": role, "content": content})
- return out
-
- from transformers import AutoTokenizer
-
- tok = AutoTokenizer.from_pretrained(
- "Qwen/Qwen3-30B-A3B-Thinking-2507",
- trust_remote_code=True,
- )
-
- test_messages = [
- {"role": "system", "content": "You are a helpful assistant."},
- {
- "role": "user",
- "content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}],
- },
- {
- "role": "assistant",
- "reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
- "content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
- },
- ]
-
- template = QwenTemplate()
- rendered_custom = "".join([template.render_message(m) for m in test_messages])
-
- qwen3_messages = to_qwen3_messages(template, test_messages)
- rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
-
- print("==== custom ====")
- print(rendered_custom)
- print("==== hf ====")
- print(rendered_hf)
-
- assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
-
- ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
- ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
- assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
diff --git a/src/llamafactory/v1/plugins/model_plugins/initialization.py b/src/llamafactory/v1/plugins/model_plugins/initialization.py
index 5e6c8bb99..efb8f22f7 100644
--- a/src/llamafactory/v1/plugins/model_plugins/initialization.py
+++ b/src/llamafactory/v1/plugins/model_plugins/initialization.py
@@ -25,12 +25,12 @@ class InitPlugin(BasePlugin):
return super().__call__()
-@InitPlugin("init_on_meta").register
+@InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device:
return torch.device(DeviceType.META.value)
-@InitPlugin("init_on_rank0").register
+@InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device:
if DistributedInterface().get_rank() == 0:
return torch.device(DeviceType.CPU.value)
@@ -38,6 +38,6 @@ def init_on_rank0() -> torch.device:
return torch.device(DeviceType.META.value)
-@InitPlugin("init_on_default").register
+@InitPlugin("init_on_default").register()
def init_on_default() -> torch.device:
- return DistributedInterface().current_accelerator
+ return DistributedInterface().current_device
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/base.py b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py
index d5cd83be6..265986ccc 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/base.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/base.py
@@ -38,17 +38,17 @@ class BaseKernel(ABC):
@classmethod
def get_kernel_id(cls) -> str:
- r"""Returns the unique identifier for the kernel."""
+ """Returns the unique identifier for the kernel."""
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
- r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
+ """Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device
@classmethod
def check_deps(cls) -> bool:
- r"""Checks if the required dependencies for the kernel are available.
+ """Checks if the required dependencies for the kernel are available.
Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
- r"""Applies the kernel optimization to the model.
+ """Applies the kernel optimization to the model.
Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
index 19a2def19..446cd57a6 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/interface.py
@@ -33,7 +33,7 @@ logger = get_logger(__name__)
def scan_all_kernels():
- r"""Scan all kernels in the ``ops`` directory.
+ """Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
def get_default_kernels():
- r"""Get a list of default registered kernel IDs.
+ """Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
@@ -86,7 +86,7 @@ def get_default_kernels():
def apply_kernel(kernel_id: str, **kwargs):
- r"""Applies a specific kernel to the model.
+ """Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
+
kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
- r"""Plugin for managing kernel optimizations."""
+ """Plugin for managing kernel optimizations."""
pass
-@KernelPlugin("auto").register
+@KernelPlugin("auto").register()
def apply_default_kernels(**kwargs):
- r"""Applies all default registered kernels to the model.
+ """Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
+
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
+
apply_kernel(kernel, **kwargs)
+
return kwargs.get("model")
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py
index 0d84dbec8..62854777d 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_fused_moe.py
@@ -40,11 +40,11 @@ from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
- r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
+ """Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
- r"""Performs the forward pass of Grouped Matrix Multiplication.
+ """Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
- r"""Performs the backward pass of Grouped Matrix Multiplication.
+ """Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function):
- r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
+ """Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
- r"""Performs the forward pass of Hybrid GMM.
+ """Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
- r"""Performs the backward pass of Hybrid GMM.
+ """Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused:
- r"""Container for NPU fused MoE forward functions."""
+ """Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
- r"""Forward pass for MoE experts using NPU fused operations.
+ """Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:
class Qwen3NpuMoeFused:
- r"""Container for Qwen3 NPU fused MoE forward functions."""
+ """Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
- r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
+ """Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
- r"""NPU Fused MoE Kernel implementation."""
+ """NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> HFModel:
- r"""Applies the NPU fused MoE kernel to the model.
+ """Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
if target_moe_mapping is None:
return model
+
for module in model.modules():
class_name = module.__class__.__name__
if class_name in target_moe_mapping:
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py
index e6f82051c..a45077bc0 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/mlp/npu_swiglu.py
@@ -38,7 +38,7 @@ except ImportError:
def npu_swiglu_forward(self, hidden_state):
- r"""SwiGLU forward pass for NPU.
+ """SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
def _npu_swiglu_glm4_forward(self, hidden_states):
- r"""SwiGLU forward pass for GLM4 on NPU.
+ """SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
- r"""SwiGLU forward pass for Gemma3nText on NPU.
+ """SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
@register_kernel
class NpuSwiGluKernel(BaseKernel):
- r"""NPU Kernel for fused SwiGLU activation."""
+ """NPU Kernel for fused SwiGLU activation."""
# just support apply to the following module layers
expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
@classmethod
def apply(cls, **kwargs) -> "HFModel":
- r"""Applies the NPU fused SwiGLU kernel to the model.
+ """Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
index 6ce36bb67..35057f451 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
@@ -30,7 +30,7 @@ from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states):
- r"""NPU forward implementation for RMSNorm.
+ """NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
@register_kernel
class NpuRMSNormKernel(BaseKernel):
- r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
+ """NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
_kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
- r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
+ """Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
+
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py
index b431b5063..2f3e290a0 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/ops/rope/npu_rope.py
@@ -40,7 +40,7 @@ except ImportError:
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
- r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
+ """Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
- r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
+ """Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
@register_kernel
class NpuRoPEKernel(BaseKernel):
- r"""NPU Kernel for Rotary Position Embedding."""
+ """NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
- r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
+ """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
+
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
+
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
+
return model
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
index f6c1984ae..16a5c8e0f 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py
@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
class Registry:
- r"""Registry for managing kernel implementations.
+ """Registry for managing kernel implementations.
Storage structure: ``{ "kernel_id": Class }``
"""
@@ -38,8 +38,8 @@ class Registry:
_kernels: dict[str, type[BaseKernel]] = {}
@classmethod
- def register(cls, kernel_cls: type[BaseKernel]):
- r"""Decorator to register a kernel class.
+ def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
+ """Decorator to register a kernel class.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
@@ -47,7 +47,7 @@ class Registry:
kernel_cls (type[BaseKernel]): The kernel class to register.
Returns:
- type[BaseKernel]: The registered kernel class.
+ type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
@@ -55,6 +55,7 @@ class Registry:
"""
if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
+
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
@@ -73,7 +74,7 @@ class Registry:
@classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
- r"""Retrieves a registered kernel implementation by its ID.
+ """Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
@@ -85,7 +86,7 @@ class Registry:
@classmethod
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
- r"""Returns a dictionary of all registered kernels.
+ """Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
diff --git a/src/llamafactory/v1/plugins/model_plugins/peft.py b/src/llamafactory/v1/plugins/model_plugins/peft.py
index 819b06d14..a3f37482c 100644
--- a/src/llamafactory/v1/plugins/model_plugins/peft.py
+++ b/src/llamafactory/v1/plugins/model_plugins/peft.py
@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
return super().__call__(model, config)
-@PeftPlugin("lora").register
+@PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
-@PeftPlugin("freeze").register
+@PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()
diff --git a/src/llamafactory/v1/plugins/model_plugins/rendering.py b/src/llamafactory/v1/plugins/model_plugins/rendering.py
new file mode 100644
index 000000000..7e9d026d2
--- /dev/null
+++ b/src/llamafactory/v1/plugins/model_plugins/rendering.py
@@ -0,0 +1,36 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...utils.plugin import BasePlugin
+from ...utils.types import Message, ModelInput, Processor
+
+
+class RenderingPlugin(BasePlugin):
+ pass
+
+
+@RenderingPlugin("qwen").register("render_messages")
+def render_qwen_messages(
+ processor: Processor,
+ messages: list[Message],
+ tools: str | None = None,
+ is_generate: bool = False,
+) -> ModelInput:
+ raise NotImplementedError()
+
+
+@RenderingPlugin("qwen").register("parse_message")
+def parse_qwen_message(generated_text: str) -> Message:
+ raise NotImplementedError()
diff --git a/src/llamafactory/v1/samplers/cli_sampler.py b/src/llamafactory/v1/samplers/cli_sampler.py
index c08bc838a..d102ac84b 100644
--- a/src/llamafactory/v1/samplers/cli_sampler.py
+++ b/src/llamafactory/v1/samplers/cli_sampler.py
@@ -12,10 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import asyncio
+import os
+from collections.abc import Generator
+from threading import Thread
-from ..config import InputArgument, SampleBackend, get_args
+from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
-from ..core.model_loader import ModelLoader
+from ..core.data_engine import DataEngine
+from ..core.model_engine import ModelEngine
+from ..core.utils.rendering import Renderer
+from ..utils.types import HFModel, Message, Sample, TorchDataset
+
+
+class SyncSampler(BaseSampler):
+ def __init__(
+ self,
+ args: SampleArguments,
+ model_args: ModelArguments,
+ model: HFModel,
+ renderer: Renderer,
+ ) -> None:
+ def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
+ asyncio.set_event_loop(loop)
+ loop.run_forever()
+
+ super().__init__(args, model_args, model, renderer)
+ self._loop = asyncio.new_event_loop()
+ self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
+ self._thread.start()
+
+ def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
+ """Generate tokens synchronously.
+
+ Args:
+ messages: List of messages.
+ tools: Tools string.
+
+ Yields:
+ Generated tokens.
+ """
+ generator = super().generate(messages, tools)
+ while True:
+ try:
+ token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
+ yield token
+ except StopAsyncIteration:
+ break
+
+ def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
+ """Batch infer samples synchronously.
+
+ Args:
+ dataset: Torch dataset.
+
+ Returns:
+ List of samples.
+ """
+ return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()
def run_chat(args: InputArgument = None):
@@ -23,12 +77,48 @@ def run_chat(args: InputArgument = None):
if sample_args.sample_backend != SampleBackend.HF:
model_args.init_plugin = {"name": "init_on_meta"}
- model_loader = ModelLoader(model_args)
- sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
+ model_engine = ModelEngine(model_args)
+ sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
if data_args.dataset is not None:
- sampler.batch_infer()
+ dataset = DataEngine(data_args)
+ sampler.batch_infer(dataset)
else:
- sampler.generate()
+ if os.name != "nt":
+ try:
+ import readline # noqa: F401
+ except ImportError:
+ print("Install `readline` for a better experience.")
+
+        messages = []
+        print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
+
+        while True:
+            try:
+                query = input("\nUser: ")
+            except UnicodeDecodeError:
+                print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
+                continue
+            except Exception:
+                raise
+
+            if query.strip() == "exit":
+                break
+
+            if query.strip() == "clear":
+                messages = []
+                print("History has been removed.")
+                continue
+
+            messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
+            print("Assistant: ", end="", flush=True)
+
+            response = ""
+            for new_text in sampler.generate(messages):
+                print(new_text, end="", flush=True)
+                response += new_text
+
+            print()
+            messages.append(model_engine.renderer.parse_message(response))
if __name__ == "__main__":
diff --git a/src/llamafactory/v1/trainers/sft_trainer.py b/src/llamafactory/v1/trainers/sft_trainer.py
index c48c1092f..2cb8d3a3c 100644
--- a/src/llamafactory/v1/trainers/sft_trainer.py
+++ b/src/llamafactory/v1/trainers/sft_trainer.py
@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
-from ..core.model_loader import ModelLoader
+from ..core.model_engine import ModelEngine
class SFTTrainer(BaseTrainer):
@@ -28,11 +28,11 @@ def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
- model_loader = ModelLoader(model_args)
+ model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
- model=model_loader.model,
- processor=model_loader.processor,
+ model=model_engine.model,
+ processor=model_engine.processor,
dataset=data_engine,
)
trainer.fit()
diff --git a/src/llamafactory/v1/utils/constants.py b/src/llamafactory/v1/utils/constants.py
index ec0d62554..9ec68b44d 100644
--- a/src/llamafactory/v1/utils/constants.py
+++ b/src/llamafactory/v1/utils/constants.py
@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
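+# Label value ignored by the loss (the default ignore_index of PyTorch's cross-entropy loss).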
+IGNORE_INDEX = -100
diff --git a/src/llamafactory/v1/utils/helper.py b/src/llamafactory/v1/utils/helper.py
new file mode 100644
index 000000000..7cbd9f336
--- /dev/null
+++ b/src/llamafactory/v1/utils/helper.py
@@ -0,0 +1,29 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import PreTrainedTokenizer
+
+from .types import Processor
+
+
+def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
+ """Get tokenizer from processor.
+
+ Args:
+ processor: Processor.
+
+ Returns:
+ Tokenizer.
+ """
+ return processor.tokenizer if hasattr(processor, "tokenizer") else processor
diff --git a/src/llamafactory/v1/utils/logging.py b/src/llamafactory/v1/utils/logging.py
index 81bc53751..4d38927ff 100644
--- a/src/llamafactory/v1/utils/logging.py
+++ b/src/llamafactory/v1/utils/logging.py
@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
def _get_library_name() -> str:
- return __name__.split(".")[0]
+ return ".".join(__name__.split(".")[:2]) # llamafactory.v1
def _get_library_root_logger() -> "_Logger":
diff --git a/src/llamafactory/v1/utils/plugin.py b/src/llamafactory/v1/utils/plugin.py
index 0e4bcdf8e..136896492 100644
--- a/src/llamafactory/v1/utils/plugin.py
+++ b/src/llamafactory/v1/utils/plugin.py
@@ -13,6 +13,7 @@
# limitations under the License.
+from collections import defaultdict
from collections.abc import Callable
from . import logging
@@ -27,7 +28,7 @@ class BasePlugin:
A plugin is a callable object that can be registered and called by name.
"""
- _registry: dict[str, Callable] = {}
+ _registry: dict[str, dict[str, Callable]] = defaultdict(dict)
def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.
@@ -37,8 +38,7 @@ class BasePlugin:
"""
self.name = name
- @property
- def register(self):
+ def register(self, method_name: str = "__call__"):
"""Decorator to register a function as a plugin.
Example usage:
@@ -46,16 +46,21 @@ class BasePlugin:
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
+
+
+ @PrintPlugin("hello").register("again")
+ def print_hello_again():
+ print("Hello world! Again.")
```
"""
if self.name is None:
- raise ValueError("Plugin name is not specified.")
+ raise ValueError("Plugin name should be specified.")
- if self.name in self._registry:
- logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
+ if method_name in self._registry[self.name]:
+ logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable:
- self._registry[self.name] = func
+ self._registry[self.name][method_name] = func
return func
return decorator
@@ -68,10 +73,23 @@ class BasePlugin:
PrintPlugin("hello")()
```
"""
- if self.name not in self._registry:
- raise ValueError(f"Plugin {self.name} is not registered.")
+ if "__call__" not in self._registry[self.name]:
+ raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
- return self._registry[self.name](*args, **kwargs)
+ return self._registry[self.name]["__call__"](*args, **kwargs)
+
+ def __getattr__(self, method_name: str):
+ """Get the registered function with the given name.
+
+ Example usage:
+ ```python
+ PrintPlugin("hello").again()
+ ```
+ """
+ if method_name not in self._registry[self.name]:
+ raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
+
+ return self._registry[self.name][method_name]
if __name__ == "__main__":
@@ -82,8 +100,13 @@ if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass
- @PrintPlugin("hello").register
+ @PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
+ @PrintPlugin("hello").register("again")
+ def print_hello_again():
+ print("Hello world! Again.")
+
PrintPlugin("hello")()
+ PrintPlugin("hello").again()
diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py
index bae334f4a..cad3d57a1 100644
--- a/src/llamafactory/v1/utils/types.py
+++ b/src/llamafactory/v1/utils/types.py
@@ -84,27 +84,59 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict):
- type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
+ type: Literal["text", "reasoning", "tool_call", "image_url"]
+ """Type of the content."""
value: str
+ """Value of the content."""
class Message(TypedDict):
role: Literal["system", "user", "assistant", "tool"]
+ """Role of the message."""
content: list[Content]
- loss_weight: float
+ """Content of the message."""
+ loss_weight: NotRequired[float]
+ """Loss weight for this message, default to 1.0. Required in training."""
class SFTSample(TypedDict):
messages: list[Message]
+ """Messages in the sample."""
extra_info: NotRequired[str]
+ """Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str]
+ """Dataset name for the sample."""
class DPOSample(TypedDict):
chosen_messages: list[Message]
+ """Chosen messages in the sample."""
rejected_messages: list[Message]
+ """Rejected messages in the sample."""
extra_info: NotRequired[str]
+ """Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str]
+ """Dataset name for the sample."""
Sample = Union[SFTSample, DPOSample]
+
+
+class ToolCall(TypedDict):
+ name: str
+ """Function name."""
+ arguments: str
+ """Function arguments."""
+
+
+class ModelInput(TypedDict, total=False):
+ input_ids: list[int]
+ """Input ids for the model."""
+ attention_mask: list[int]
+ """Attention mask for the model."""
+ labels: list[int]
+ """Labels for the model."""
+ loss_weights: list[float]
+ """Loss weight for each token, default to 1.0."""
+ position_ids: NotRequired[list[int] | list[list[int]]]
+ """Position ids for the model (optional)."""
diff --git a/tests/conftest.py b/tests/conftest.py
index cd20a0d2b..835a15980 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -18,7 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
-from typing import Optional
+import sys
import pytest
import torch
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)
-def _get_visible_devices_env() -> Optional[str]:
+def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -149,6 +149,14 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
+
+ # add project root dir to path for mp run
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ if project_root not in sys.path:
+ sys.path.insert(0, project_root)
+
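+    # Export it as well so spawned worker processes inherit the project import path.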
+ os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
+
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]
diff --git a/tests/version.txt b/tests/version.txt
index 08eb069fe..53686d274 100644
--- a/tests/version.txt
+++ b/tests/version.txt
@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
-0.9.4.105
+0.9.5.101
diff --git a/tests_v1/core/test_data_loader.py b/tests_v1/core/test_data_loader.py
deleted file mode 100644
index 329098242..000000000
--- a/tests_v1/core/test_data_loader.py
+++ /dev/null
@@ -1,173 +0,0 @@
-# Copyright 2025 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
-
-Tests the 4 scenarios:
-a) non pack + non dynamic.
-b) non pack + dynamic.
-c) pack + non dynamic.
-d) pack + dynamic.
-"""
-
-import torch
-from torch.utils.data import DataLoader as TorchDataLoader
-from torch.utils.data import Dataset
-from transformers import AutoTokenizer
-
-from llamafactory.v1.config.data_args import DataArguments
-from llamafactory.v1.core.data_engine import DataEngine
-from llamafactory.v1.core.trainer_utils.data_collator import (
- DefaultCollator,
-)
-from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
-from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
-from llamafactory.v1.utils.batching_queue import TextBatchingQueue
-
-
-class TensorDataset(Dataset):
- """Wrapper dataset that converts DataEngine samples to tensor format."""
-
- def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
- self.data_engine = data_engine
- self.processor = processor
- self.template = template
- self.max_samples = max_samples or len(data_engine)
- self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
-
- def __len__(self):
- return min(self.max_samples, len(self.data_engine))
-
- def __getitem__(self, idx):
- # Get sample from DataEngine
- sample = self.data_engine[idx]
-
- # Extract messages from sample
- # DataEngine returns samples with format like {"messages": [...], ...}
- # For llamafactory/v1-sft-demo, the format should have "messages" field
- messages = None
- if "messages" in sample:
- messages = sample["messages"]
- elif "conversations" in sample:
- messages = sample["conversations"]
- elif "conversation" in sample:
- messages = sample["conversation"]
- else:
- # Try to find message-like fields (skip _dataset_name)
- for key, value in sample.items():
- if key.startswith("_"):
- continue
- if isinstance(value, list) and len(value) > 0:
- # Check if it looks like a message list
- if isinstance(value[0], dict) and "role" in value[0]:
- messages = value
- break
-
- if messages is None:
- raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
-
- # Encode messages using template
- encoded = self.template.encode_messages(self.tokenizer, messages)
-
- # Convert to tensors
- return {
- "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
- "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
- "labels": torch.tensor(encoded["labels"], dtype=torch.long),
- }
-
-
-def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
- """Create a real dataset using DataEngine."""
- data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
- data_engine = DataEngine(data_args)
-
- # Create processor and template
- processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
- template = QwenTemplate()
-
- # Create tensor dataset
- raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
-
- # Create torch DataLoader
- torch_dataloader = TorchDataLoader(
- raw_data_dataset,
- batch_size=batch_size,
- shuffle=False,
- collate_fn=lambda x: x,
- )
-
- return torch_dataloader, processor, template
-
-
-class TestDataLoaderNonPackNonDynamic:
- """Test case a) non pack + non dynamic."""
-
- def test_basic_functionality(self):
- """Test DataLoader without packing and without dynamic batching."""
- # Create real dataset
- torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
-
- # Create collator (non-packing)
- collator = DefaultCollator(processor=processor, template=template)
-
- # Create DataLoader without batching_queue (non-dynamic)
- data_loader = DataLoader(
- dataloader=torch_dataloader,
- collate_fn=collator,
- num_micro_batch=1,
- batching_queue=None,
- )
-
- # Iterate and check results
- batches = list(iter(data_loader))
- assert len(batches) > 0
-
- # Check first batch
- one_batch = batches[0]
- micro_batches = one_batch[0]
- assert "input_ids" in micro_batches
- assert "attention_mask" in micro_batches
- assert "labels" in micro_batches
- assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
- assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
-
-
-class TestDataLoaderNonPackDynamic:
- """Test case b) non pack + dynamic."""
-
- def test_basic_functionality(self):
- """Test DataLoader without packing but with dynamic batching."""
- # Create real dataset
- torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
- collator = DefaultCollator(processor=processor, template=template)
-
- # Create batching queue for dynamic batching
- batching_queue = TextBatchingQueue(
- token_micro_bsz=120,
- buffer_size=8,
- )
-
- data_loader = DataLoader(
- dataloader=torch_dataloader,
- collate_fn=collator,
- num_micro_batch=4,
- batching_queue=batching_queue,
- )
-
- # Iterate and check
- batches = list(iter(data_loader))
- micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
- assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
- assert len(batches) > 0
diff --git a/tests_v1/core/test_model_loader.py b/tests_v1/core/test_model_loader.py
index fa038229e..5f0150b63 100644
--- a/tests_v1/core/test_model_loader.py
+++ b/tests_v1/core/test_model_loader.py
@@ -15,18 +15,18 @@
import torch
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
-from llamafactory.v1.core.model_loader import ModelLoader
+from llamafactory.v1.core.model_engine import ModelEngine
def test_tiny_qwen():
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
- model_loader = ModelLoader(model_args)
- assert isinstance(model_loader.processor, Qwen2TokenizerFast)
- assert isinstance(model_loader.model.config, Qwen2Config)
- assert isinstance(model_loader.model, Qwen2ForCausalLM)
- assert model_loader.model.dtype == torch.bfloat16
+ model_engine = ModelEngine(model_args)
+ assert isinstance(model_engine.processor, Qwen2TokenizerFast)
+ assert isinstance(model_engine.model_config, Qwen2Config)
+ assert isinstance(model_engine.model, Qwen2ForCausalLM)
+ assert model_engine.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
@@ -37,13 +37,14 @@ def test_tiny_qwen_with_kernel_plugin():
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
- model_loader = ModelLoader(model_args)
+ model_engine = ModelEngine(model_args)
# test enable apply kernel plugin
if hasattr(torch, "npu"):
- assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
+ assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
- assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
- assert isinstance(model_loader.model, Qwen2ForCausalLM)
+ assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
+
+ assert isinstance(model_engine.model, Qwen2ForCausalLM)
if __name__ == "__main__":
diff --git a/tests_v1/core/utils/test_data_loader.py b/tests_v1/core/utils/test_data_loader.py
new file mode 100644
index 000000000..cbddf0887
--- /dev/null
+++ b/tests_v1/core/utils/test_data_loader.py
@@ -0,0 +1,171 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
+
+Tests the 4 scenarios:
+a) non pack + non dynamic.
+b) non pack + dynamic.
+c) pack + non dynamic.
+d) pack + dynamic.
+"""
+
+# import torch
+# from torch.utils.data import DataLoader as TorchDataLoader
+# from torch.utils.data import Dataset
+# from transformers import AutoTokenizer
+
+# from llamafactory.v1.config.data_args import DataArguments
+# from llamafactory.v1.core.data_engine import DataEngine
+# from llamafactory.v1.core.utils.data_collator import DefaultCollator
+# from llamafactory.v1.core.utils.data_loader import DataLoader
+# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
+# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
+
+
+# class TensorDataset(Dataset):
+# """Wrapper dataset that converts DataEngine samples to tensor format."""
+
+# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
+# self.data_engine = data_engine
+# self.processor = processor
+# self.template = template
+# self.max_samples = max_samples or len(data_engine)
+# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
+
+# def __len__(self):
+# return min(self.max_samples, len(self.data_engine))
+
+# def __getitem__(self, idx):
+# # Get sample from DataEngine
+# sample = self.data_engine[idx]
+
+# # Extract messages from sample
+# # DataEngine returns samples with format like {"messages": [...], ...}
+# # For llamafactory/v1-sft-demo, the format should have "messages" field
+# messages = None
+# if "messages" in sample:
+# messages = sample["messages"]
+# elif "conversations" in sample:
+# messages = sample["conversations"]
+# elif "conversation" in sample:
+# messages = sample["conversation"]
+# else:
+# # Try to find message-like fields (skip _dataset_name)
+# for key, value in sample.items():
+# if key.startswith("_"):
+# continue
+# if isinstance(value, list) and len(value) > 0:
+# # Check if it looks like a message list
+# if isinstance(value[0], dict) and "role" in value[0]:
+# messages = value
+# break
+
+# if messages is None:
+# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
+
+# # Encode messages using template
+# encoded = self.template.encode_messages(self.tokenizer, messages)
+
+# # Convert to tensors
+# return {
+# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
+# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
+# "labels": torch.tensor(encoded["labels"], dtype=torch.long),
+# }
+
+
+# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
+# """Create a real dataset using DataEngine."""
+# data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
+# data_engine = DataEngine(data_args)
+
+# # Create processor and template
+# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
+# template = QwenTemplate()
+
+# # Create tensor dataset
+# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
+
+# # Create torch DataLoader
+# torch_dataloader = TorchDataLoader(
+# raw_data_dataset,
+# batch_size=batch_size,
+# shuffle=False,
+# collate_fn=lambda x: x,
+# )
+
+# return torch_dataloader, processor, template
+
+
+# class TestDataLoaderNonPackNonDynamic:
+# """Test case a) non pack + non dynamic."""
+
+# def test_basic_functionality(self):
+# """Test DataLoader without packing and without dynamic batching."""
+# # Create real dataset
+# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
+
+# # Create collator (non-packing)
+# collator = DefaultCollator(processor=processor, template=template)
+
+# # Create DataLoader without batching_queue (non-dynamic)
+# data_loader = DataLoader(
+# dataloader=torch_dataloader,
+# collate_fn=collator,
+# num_micro_batch=1,
+# batching_queue=None,
+# )
+
+# # Iterate and check results
+# batches = list(iter(data_loader))
+# assert len(batches) > 0
+
+# # Check first batch
+# one_batch = batches[0]
+# micro_batches = one_batch[0]
+# assert "input_ids" in micro_batches
+# assert "attention_mask" in micro_batches
+# assert "labels" in micro_batches
+# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
+# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
+
+
+# class TestDataLoaderNonPackDynamic:
+# """Test case b) non pack + dynamic."""
+
+# def test_basic_functionality(self):
+# """Test DataLoader without packing but with dynamic batching."""
+# # Create real dataset
+# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
+# collator = DefaultCollator(processor=processor, template=template)
+
+# # Create batching queue for dynamic batching
+# batching_queue = TextBatchingQueue(
+# token_micro_bsz=120,
+# buffer_size=8,
+# )
+
+# data_loader = DataLoader(
+# dataloader=torch_dataloader,
+# collate_fn=collator,
+# num_micro_batch=4,
+# batching_queue=batching_queue,
+# )
+
+# # Iterate and check
+# batches = list(iter(data_loader))
+# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
+# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
+# assert len(batches) > 0
diff --git a/tests_v1/core/utils/test_rendering.py b/tests_v1/core/utils/test_rendering.py
new file mode 100644
index 000000000..bca3d00c9
--- /dev/null
+++ b/tests_v1/core/utils/test_rendering.py
@@ -0,0 +1,65 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import AutoTokenizer
+
+from llamafactory.v1.core.utils.rendering import Renderer
+from llamafactory.v1.utils.types import Processor
+
+
+HF_MESSAGES = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is LLM?"},
+ {"role": "assistant", "content": "LLM stands for Large Language Model."},
+]
+V1_MESSAGES = [
+ {"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
+ {"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
+ {"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
+]
+
+
+def test_chatml_rendering():
+ tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
+ renderer = Renderer(template="chatml", processor=tokenizer)
+
+ hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
+ v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
+ assert v1_inputs["input_ids"] == hf_inputs
+ assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
+ assert v1_inputs["labels"] == [-100] * len(hf_inputs)
+ assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
+
+ hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
+ hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
+ v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
+ assert v1_inputs_full["input_ids"] == hf_inputs_full
+ assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
+ assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
+ assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
+ len(hf_inputs_full) - len(hf_inputs_part)
+ )
+
+
+def test_chatml_parse():
+ tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
+ renderer = Renderer(template="chatml", processor=tokenizer)
+ generated_text = "LLM stands for Large Language Model."
+ parsed_message = renderer.parse_message(generated_text)
+ assert parsed_message == V1_MESSAGES[-1]
+
+
+if __name__ == "__main__":
+ test_chatml_rendering()
+ test_chatml_parse()
diff --git a/tests_v1/plugins/data_plugins/test_converter.py b/tests_v1/plugins/data_plugins/test_converter.py
index 0ad40aa19..2a12cf55d 100644
--- a/tests_v1/plugins/data_plugins/test_converter.py
+++ b/tests_v1/plugins/data_plugins/test_converter.py
@@ -54,7 +54,7 @@ def test_sharegpt_converter():
"conversations": [
{"from": "system", "value": "System"},
{"from": "human", "value": "User"},
- {"from": "function_call", "value": "Tool"},
+ {"from": "function_call", "value": "1"},
{"from": "observation", "value": "Observation"},
{"from": "gpt", "value": "Assistant"},
]
@@ -63,7 +63,7 @@ def test_sharegpt_converter():
"messages": [
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
- {"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"},
+ {"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
]
diff --git a/tests_v1/plugins/model_plugins/test_init_plugin.py b/tests_v1/plugins/model_plugins/test_init_plugin.py
index 80e9d178b..d13294c8e 100644
--- a/tests_v1/plugins/model_plugins/test_init_plugin.py
+++ b/tests_v1/plugins/model_plugins/test_init_plugin.py
@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import pytest
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
-from llamafactory.v1.core.model_loader import ModelLoader
+from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
@@ -26,11 +25,10 @@ def test_init_on_meta():
init_config={"name": "init_on_meta"},
)
)
- model_loader = ModelLoader(model_args=model_args)
- assert model_loader.model.device.type == "meta"
+ model_engine = ModelEngine(model_args=model_args)
+ assert model_engine.model.device.type == "meta"
-@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
_, model_args, *_ = get_args(
dict(
@@ -38,11 +36,11 @@ def test_init_on_rank0():
init_config={"name": "init_on_rank0"},
)
)
- model_loader = ModelLoader(model_args=model_args)
+ model_engine = ModelEngine(model_args=model_args)
if DistributedInterface().get_rank() == 0:
- assert model_loader.model.device.type == "cpu"
+ assert model_engine.model.device.type == "cpu"
else:
- assert model_loader.model.device.type == "meta"
+ assert model_engine.model.device.type == "meta"
def test_init_on_default():
@@ -52,5 +50,5 @@ def test_init_on_default():
init_config={"name": "init_on_default"},
)
)
- model_loader = ModelLoader(model_args=model_args)
- assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
+ model_engine = ModelEngine(model_args=model_args)
+ assert model_engine.model.device == DistributedInterface().current_device
diff --git a/tests_v1/sampler/test_cli_sampler.py b/tests_v1/sampler/test_cli_sampler.py
new file mode 100644
index 000000000..cd20432a3
--- /dev/null
+++ b/tests_v1/sampler/test_cli_sampler.py
@@ -0,0 +1,41 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from llamafactory.v1.config import ModelArguments, SampleArguments
+from llamafactory.v1.core.model_engine import ModelEngine
+from llamafactory.v1.samplers.cli_sampler import SyncSampler
+
+
+@pytest.mark.runs_on(["cuda", "npu"])
+def test_sync_sampler():
+ model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507")
+ sample_args = SampleArguments()
+ model_engine = ModelEngine(model_args)
+ sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
+ messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}]
+ response = ""
+ for new_text in sampler.generate(messages):
+ response += new_text
+
+ print(response)
+ assert model_engine.renderer.parse_message(response) == {
+ "role": "assistant",
+ "content": [{"type": "text", "value": "This is a test."}],
+ }
+
+
+if __name__ == "__main__":
+ test_sync_sampler()