[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -55,12 +55,12 @@ jobs:
uv pip install -e .
uv pip install -r requirements/dev.txt
- name: Cache HuggingFace models
- name: Cache files
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality
run: |

View File

@@ -73,7 +73,7 @@ dependencies = [
# api
"uvicorn",
"fastapi",
"sse-starlette"
"sse-starlette",
]
[project.scripts]

View File

@@ -119,9 +119,19 @@ def synchronize() -> None:
@requires_accelerator
def set_device() -> None:
"""Set current accelerator."""
torch.accelerator.set_device_index(get_local_rank())
def set_device_index() -> None:
"""Set current accelerator index to local rank."""
if get_current_accelerator().type != DeviceType.CPU:
torch.accelerator.set_device_index(get_local_rank())
@requires_accelerator
def get_current_device() -> torch.device:
"""Get current accelerator device."""
if get_current_accelerator().type == DeviceType.CPU:
return torch.device(DeviceType.CPU.value)
else:
return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())
def is_torch_cuda_available():

View File

@@ -123,12 +123,13 @@ class DistributedInterface:
if self._initialized:
return
helper.set_device_index()
self._is_distributed = helper.is_distributed()
self._rank = helper.get_rank()
self._world_size = helper.get_world_size()
self._local_rank = helper.get_local_rank()
self._local_world_size = helper.get_local_world_size()
self.current_accelerator = helper.get_current_accelerator()
self.current_device = helper.get_current_device()
self.device_count = helper.get_device_count()
if config is None:
@@ -144,15 +145,14 @@ class DistributedInterface:
timeout = config.get("timeout", 18000)
if self._is_distributed:
helper.set_device()
init_process_group(timeout=timedelta(seconds=timeout))
self.model_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.model_mesh_shape,
mesh_dim_names=self.strategy.model_mesh_dim_names,
)
self.data_device_mesh = init_device_mesh(
device_type=self.current_accelerator.type,
device_type=self.current_device.type,
mesh_shape=self.strategy.data_mesh_shape,
mesh_dim_names=self.strategy.data_mesh_dim_names,
)
@@ -161,12 +161,12 @@ class DistributedInterface:
self.data_device_mesh = None
self._initialized = True
logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")
logger.info_rank0(f"DistributedInterface initialized: {self}.")
def __str__(self) -> str:
return (
f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
@@ -251,4 +251,7 @@ class DistributedInterface:
if __name__ == "__main__":
print(DistributedInterface(DistributedStrategy()))
"""
python -m llamafactory.v1.accelerator.interface
"""
print(DistributedInterface())

View File

@@ -17,7 +17,7 @@
import json
from enum import Enum, unique
from enum import StrEnum, unique
class PluginConfig(dict):
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None
@unique
class ModelClass(str, Enum):
class ModelClass(StrEnum):
"""Auto class for model config."""
LLM = "llm"
@@ -45,7 +45,7 @@ class ModelClass(str, Enum):
@unique
class SampleBackend(str, Enum):
class SampleBackend(StrEnum):
HF = "hf"
VLLM = "vllm"

View File

@@ -21,8 +21,13 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
model: str = field(
default="Qwen/Qwen3-4B-Instruct-2507",
metadata={"help": "Path to the model or model identifier from Hugging Face."},
)
template: str = field(
default="chatml",
metadata={"help": "Template for the model."},
)
trust_remote_code: bool = field(
default=False,
metadata={"help": "Trust remote code from Hugging Face."},

View File

@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from threading import Thread
import torch
from transformers import TextIteratorStreamer
from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Processor, TorchDataset
from ..utils.helper import get_tokenizer
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.rendering import Renderer
class BaseEngine(ABC):
@@ -24,8 +34,8 @@ class BaseEngine(ABC):
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel = None,
processor: Processor = None,
model: HFModel,
renderer: Renderer,
) -> None:
"""Initialize the engine.
@@ -33,17 +43,34 @@ class BaseEngine(ABC):
args: Sample arguments.
model_args: Model arguments.
model: Model.
processor: Processor.
renderer: Renderer.
"""
...
@abstractmethod
async def generate(self, messages):
pass
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod
async def batch_infer(self, data: TorchDataset) -> None:
pass
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine):
@@ -52,26 +79,103 @@ class HuggingFaceEngine(BaseEngine):
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
renderer: Renderer,
) -> None:
self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = TextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()
return stream
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
response = self.get_response(messages, tools)
while True:
try:
yield await asyncio.to_thread(response)
except StopAsyncIteration:
break
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")
class BaseSampler:
"""Base sampler.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
renderer: Renderer,
) -> None:
if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, processor)
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
async def generate(self, messages):
return await self.engine.generate(messages)
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
async def batch_infer(self, data: TorchDataset) -> None:
return await self.engine.batch_infer(data)
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
async for token in self.engine.generate(messages, tools):
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return await self.engine.batch_infer(dataset)

View File

@@ -14,15 +14,23 @@
"""The definition of data engine.
Init Data engine:
How to use:
data_engine = DataEngine(data_args)
data_engine[i]: Get the sample via index.
Init workflow:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
Get Data Sample:
Get data sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
Note:
1. The data engine is equivalent to the torch dataset.
2. The data engine is agnostic to the model used.
"""
import os
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
if size or weight:
from ..plugins.data_plugins.loader import adjust_data_index
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
data_index = adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin
from ..plugins.data_plugins.loader import select_data_sample
selected_index = DataSelectorPlugin().select(self.data_index, index)
selected_index = select_data_sample(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

View File

@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
"""The definition of model engine.
How to use:
model_loader = ModelLoader(model_args, is_trainable=True)
model_loader.processor: Get the tokenizer or multi-modal processor.
model_loader.model_config: Get the model configuration.
model_loader.model: Get the HF model.
model_engine = ModelEngine(model_args, is_train=True)
model_engine.processor: Get the tokenizer or multi-modal processor.
model_engine.renderer: Get the renderer.
model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init Workflow:
Init workflow:
1. Init processor.
2. Init render.
2. Init model config.
3. Init model.
4. Init adapter.
@@ -36,17 +38,18 @@ from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class ModelLoader:
"""Model loader.
class ModelEngine:
"""Model engine.
Args:
model_args: Model arguments.
is_trainable: Whether to train the model.
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -56,6 +59,8 @@ class ModelLoader:
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.renderer = Renderer(self.args.template, self.processor)
"""Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
@@ -107,7 +112,7 @@ class ModelLoader:
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_accelerator
init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META:
with init_empty_weights():
@@ -144,12 +149,12 @@ class ModelLoader:
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args)
print(model_loader.processor)
print(model_loader.model_config)
print(model_loader.model)
model_engine = ModelEngine(model_args=model_args)
print(model_engine.processor)
print(model_engine.model_config)
print(model_engine.model)

View File

@@ -0,0 +1,239 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def render_chatml_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in ChatML format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
class Renderer:
def __init__(self, template: str, processor: Processor):
self.template = template
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)

View File

@@ -49,6 +49,11 @@ def launch():
run_sft()
elif command == "chat":
from .samplers.cli_sampler import run_chat
run_chat()
elif command == "env":
print_env()

View File

@@ -13,11 +13,12 @@
# limitations under the License.
import json
from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging
from ...utils.plugin import BasePlugin
from ...utils.types import DPOSample, Sample, SFTSample
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
logger = logging.get_logger(__name__)
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
return super().__call__(raw_sample)
@DataConverterPlugin("alpaca").register
@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
"""Convert Alpaca sample to SFT sample.
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
return {"messages": messages}
@DataConverterPlugin("sharegpt").register
@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"""Convert ShareGPT sample to SFT sample.
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
"function_call": "assistant",
}
messages = []
tools = raw_sample.get("tools", "")
tools = raw_sample.get("tools")
if tools:
try:
tools: list[dict[str, Any]] = json.loads(tools)
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
tools = []
for message in raw_sample.get("conversations", []):
tag = message["from"]
if tag not in tag_mapping:
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
elif tag == "function_call":
try:
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
except json.JSONDecodeError:
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
continue
if not isinstance(tool_calls, list):
tool_calls = [tool_calls]
messages.append(
{
"role": "assistant",
"content": [{"type": "tool_calls", "value": message["value"]}],
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
"loss_weight": 1.0,
}
)
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
)
if tools:
if messages and messages[0]["role"] == "system":
messages[0]["content"].append({"type": "tools", "value": tools})
else:
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
return {"messages": messages}
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
else:
return {"messages": messages}
@DataConverterPlugin("pair").register
@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
"""Convert Pair sample to DPO sample.

View File

@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
raise ValueError(f"Unknown dataset filetype: {filetype}.")
@DataLoaderPlugin("local").register
@DataLoaderPlugin("local").register()
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
if os.path.isdir(filepath):
filetype = _get_builder_name(os.listdir(filepath)[0])
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
return dataset
class DataIndexPlugin(BasePlugin):
"""Plugin for adjusting dataset index."""
def adjust_data_index(
data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
def adjust_data_index(
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
size (Optional[int]): Desired dataset size.
weight (Optional[float]): Desired dataset weight.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
size (Optional[int]): Desired dataset size.
weight (Optional[float]): Desired dataset weight.
Returns:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = random.choices(data_index, k=size)
Returns:
list[tuple[str, int]]: Adjusted dataset index.
"""
if size is not None:
data_index = random.choices(data_index, k=size)
if weight is not None:
data_index = random.choices(data_index, k=int(len(data_index) * weight))
if weight is not None:
data_index = random.choices(data_index, k=int(len(data_index) * weight))
return data_index
return data_index
class DataSelectorPlugin(BasePlugin):
"""Plugin for selecting dataset samples."""
def select_data_sample(
data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
def select(
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Args:
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
index (Union[slice, list[int], Any]): Index of dataset samples.
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")
Returns:
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
"""
if isinstance(index, slice):
return [data_index[i] for i in range(*index.indices(len(data_index)))]
elif isinstance(index, list):
return [data_index[i] for i in index]
else:
raise ValueError(f"Invalid index type {type(index)}.")

View File

@@ -1,133 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class Template:
user_template: str
assistant_template: str
system_template: str
def render_message(self, message: dict[str, str]) -> str:
return self.user_template.format(**message)
@dataclass
class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
if isinstance(content_data, list):
parts = []
for item in content_data:
if item.get("type") == "text":
parts.append(item.get("value", ""))
elif item.get("type") == "image_url":
pass
return "\n".join(parts).strip()
return ""
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))
if role == "assistant":
reasoning_content = message.get("reasoning_content", "")
if reasoning_content:
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
return self.message_template.format(role="assistant", content=reasoning_content + content)
else:
return self.message_template.format(role=role, content=content)
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
"""Encode one message."""
input_ids, attention_mask, labels = [], [], []
for message in messages:
content_str = self.render_message(message)
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
input_ids += content_ids
attention_mask += [1] * len(content_ids)
if hasattr(message, "loss_weight"):
loss_weight = message["loss_weight"]
else:
loss_weight = 1 if message["role"] == "assistant" else 0
if loss_weight == 1:
labels += content_ids
else:
labels += [-100] * len(content_ids)
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
model_inputs.update({"position_ids": list(range(len(input_ids)))})
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
return model_inputs
if __name__ == "__main__":
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
out = []
for m in messages:
role = m["role"]
content = template._extract_content(m.get("content", ""))
if role == "assistant":
reasoning = (m.get("reasoning_content") or "").strip()
if reasoning:
content = template.thinking_template.format(content=reasoning) + content
out.append({"role": role, "content": content})
return out
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained(
"Qwen/Qwen3-30B-A3B-Thinking-2507",
trust_remote_code=True,
)
test_messages = [
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [{"type": "text", "text": "1+1等于几"}, {"type": "text", "text": "2+2等于几"}],
},
{
"role": "assistant",
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
},
]
template = QwenTemplate()
rendered_custom = "".join([template.render_message(m) for m in test_messages])
qwen3_messages = to_qwen3_messages(template, test_messages)
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
print("==== custom ====")
print(rendered_custom)
print("==== hf ====")
print(rendered_hf)
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"

View File

@@ -25,12 +25,12 @@ class InitPlugin(BasePlugin):
return super().__call__()
@InitPlugin("init_on_meta").register
@InitPlugin("init_on_meta").register()
def init_on_meta() -> torch.device:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_rank0").register
@InitPlugin("init_on_rank0").register()
def init_on_rank0() -> torch.device:
if DistributedInterface().get_rank() == 0:
return torch.device(DeviceType.CPU.value)
@@ -38,6 +38,6 @@ def init_on_rank0() -> torch.device:
return torch.device(DeviceType.META.value)
@InitPlugin("init_on_default").register
@InitPlugin("init_on_default").register()
def init_on_default() -> torch.device:
return DistributedInterface().current_accelerator
return DistributedInterface().current_device

View File

@@ -38,17 +38,17 @@ class BaseKernel(ABC):
@classmethod
def get_kernel_id(cls) -> str:
r"""Returns the unique identifier for the kernel."""
"""Returns the unique identifier for the kernel."""
return cls._kernel_id
@classmethod
def get_device(cls) -> str:
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
return cls._device
@classmethod
def check_deps(cls) -> bool:
r"""Checks if the required dependencies for the kernel are available.
"""Checks if the required dependencies for the kernel are available.
Returns:
bool: ``True`` if dependencies are met, ``False`` otherwise.
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
@classmethod
@abstractmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the kernel optimization to the model.
"""Applies the kernel optimization to the model.
Args:
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.

View File

@@ -33,7 +33,7 @@ logger = get_logger(__name__)
def scan_all_kernels():
r"""Scan all kernels in the ``ops`` directory.
"""Scan all kernels in the ``ops`` directory.
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
def get_default_kernels():
r"""Get a list of default registered kernel IDs.
"""Get a list of default registered kernel IDs.
Returns:
list[str]: List of kernel IDs.
@@ -86,7 +86,7 @@ def get_default_kernels():
def apply_kernel(kernel_id: str, **kwargs):
r"""Applies a specific kernel to the model.
"""Applies a specific kernel to the model.
Args:
kernel_id (str): The ID of the kernel to apply.
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
kernel = default_kernels.get(kernel_id)
if kernel is None:
raise ValueError(f"Kernel {kernel_id} not found")
kernel.apply(**kwargs)
class KernelPlugin(BasePlugin):
r"""Plugin for managing kernel optimizations."""
"""Plugin for managing kernel optimizations."""
pass
@KernelPlugin("auto").register
@KernelPlugin("auto").register()
def apply_default_kernels(**kwargs):
r"""Applies all default registered kernels to the model.
"""Applies all default registered kernels to the model.
Args:
**kwargs: Keyword arguments passed to the kernel application function.
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
use_kernels = default_kernels.keys()
else:
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
for kernel in use_kernels:
if kernel not in default_kernels:
raise ValueError(f"Kernel {kernel} not found")
apply_kernel(kernel, **kwargs)
return kwargs.get("model")

View File

@@ -40,11 +40,11 @@ from ...registry import register_kernel
class GmmFunction(torch.autograd.Function):
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
@staticmethod
def forward(ctx, x, weight, group_list):
r"""Performs the forward pass of Grouped Matrix Multiplication.
"""Performs the forward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object to save tensors for backward pass.
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_output):
r"""Performs the backward pass of Grouped Matrix Multiplication.
"""Performs the backward pass of Grouped Matrix Multiplication.
Args:
ctx: Context object containing saved tensors.
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
class HybridGmmFunction(torch.autograd.Function):
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
@staticmethod
def forward(ctx, num_experts, *args):
r"""Performs the forward pass of Hybrid GMM.
"""Performs the forward pass of Hybrid GMM.
Args:
ctx: Context object to save tensors.
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, *grad_outputs):
r"""Performs the backward pass of Hybrid GMM.
"""Performs the backward pass of Hybrid GMM.
Args:
ctx: Context object containing saved tensors.
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
class NpuMoeFused:
r"""Container for NPU fused MoE forward functions."""
"""Container for NPU fused MoE forward functions."""
@staticmethod
def npu_moe_experts_forward(
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
) -> torch.Tensor:
r"""Forward pass for MoE experts using NPU fused operations.
"""Forward pass for MoE experts using NPU fused operations.
Args:
self: The MoE layer instance.
@@ -230,11 +230,11 @@ class NpuMoeFused:
class Qwen3NpuMoeFused:
r"""Container for Qwen3 NPU fused MoE forward functions."""
"""Container for Qwen3 NPU fused MoE forward functions."""
@staticmethod
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
Args:
self: The Qwen3 MoE block instance.
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
@register_kernel
class NpuFusedMoEKernel(BaseKernel):
r"""NPU Fused MoE Kernel implementation."""
"""NPU Fused MoE Kernel implementation."""
_kernel_id = "npu_fused_moe"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> HFModel:
r"""Applies the NPU fused MoE kernel to the model.
"""Applies the NPU fused MoE kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
if target_moe_mapping is None:
return model
for module in model.modules():
class_name = module.__class__.__name__
if class_name in target_moe_mapping:

View File

@@ -38,7 +38,7 @@ except ImportError:
def npu_swiglu_forward(self, hidden_state):
r"""SwiGLU forward pass for NPU.
"""SwiGLU forward pass for NPU.
Args:
self: The MLP layer instance.
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
def _npu_swiglu_glm4_forward(self, hidden_states):
r"""SwiGLU forward pass for GLM4 on NPU.
"""SwiGLU forward pass for GLM4 on NPU.
Args:
self: The GLM4 MLP layer instance.
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
r"""SwiGLU forward pass for Gemma3nText on NPU.
"""SwiGLU forward pass for Gemma3nText on NPU.
Args:
self: The Gemma3nText MLP layer instance.
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
@register_kernel
class NpuSwiGluKernel(BaseKernel):
r"""NPU Kernel for fused SwiGLU activation."""
"""NPU Kernel for fused SwiGLU activation."""
# just support apply to the following module layers
expect_modules = frozenset(
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Applies the NPU fused SwiGLU kernel to the model.
"""Applies the NPU fused SwiGLU kernel to the model.
Args:
**kwargs: Keyword arguments containing the model.

View File

@@ -30,7 +30,7 @@ from ...registry import register_kernel
def npu_rms_norm_forward(self, hidden_states):
r"""NPU forward implementation for RMSNorm.
"""NPU forward implementation for RMSNorm.
Args:
self: RMSNorm module instance with `weight` and `variance_epsilon`.
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
@register_kernel
class NpuRMSNormKernel(BaseKernel):
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
_kernel_id = "npu_fused_rmsnorm"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
Key points:
- Match modules whose class name contains "RMSNorm" (case-insensitive).
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
for name, module in model.named_modules():

View File

@@ -40,7 +40,7 @@ except ImportError:
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
Args:
q (Tensor): Query tensor.
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
Args:
q (Tensor): Query tensor.
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
@register_kernel
class NpuRoPEKernel(BaseKernel):
r"""NPU Kernel for Rotary Position Embedding."""
"""NPU Kernel for Rotary Position Embedding."""
_kernel_id = "npu_fused_rope"
_device = DeviceType.NPU
@classmethod
def apply(cls, **kwargs) -> "HFModel":
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
This function iterates through the model's modules to find attention layers,
identifies the module where they are defined, and replaces the original
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
"""
if not cls.check_deps():
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
model = kwargs.get("model", None)
if model is None:
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
_modules = set()
for module in model.modules():
if "Attention" in module.__class__.__name__:
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
_modules.add(module_name)
except Exception as e:
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
return model

View File

@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
class Registry:
r"""Registry for managing kernel implementations.
"""Registry for managing kernel implementations.
Storage structure: ``{ "kernel_id": Class }``
"""
@@ -38,8 +38,8 @@ class Registry:
_kernels: dict[str, type[BaseKernel]] = {}
@classmethod
def register(cls, kernel_cls: type[BaseKernel]):
r"""Decorator to register a kernel class.
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
"""Decorator to register a kernel class.
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
@@ -47,7 +47,7 @@ class Registry:
kernel_cls (type[BaseKernel]): The kernel class to register.
Returns:
type[BaseKernel]: The registered kernel class.
type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator
Raises:
TypeError: If the class does not inherit from :class:`BaseKernel`.
@@ -55,6 +55,7 @@ class Registry:
"""
if not issubclass(kernel_cls, BaseKernel):
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
kernel_id = kernel_cls.get_kernel_id()
device = kernel_cls.get_device()
@@ -73,7 +74,7 @@ class Registry:
@classmethod
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
r"""Retrieves a registered kernel implementation by its ID.
"""Retrieves a registered kernel implementation by its ID.
Args:
kernel_id (str): The ID of the kernel to retrieve.
@@ -85,7 +86,7 @@ class Registry:
@classmethod
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
r"""Returns a dictionary of all registered kernels.
"""Returns a dictionary of all registered kernels.
Returns:
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.

View File

@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
return super().__call__(model, config)
@PeftPlugin("lora").register
@PeftPlugin("lora").register()
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
@PeftPlugin("freeze").register
@PeftPlugin("freeze").register()
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()

View File

@@ -0,0 +1,36 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor
class RenderingPlugin(BasePlugin):
pass
@RenderingPlugin("qwen").register("render_messages")
def render_qwen_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
raise NotImplementedError()
@RenderingPlugin("qwen").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
raise NotImplementedError()

View File

@@ -12,10 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from collections.abc import Generator
from threading import Thread
from ..config import InputArgument, SampleBackend, get_args
from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.model_loader import ModelLoader
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.utils.rendering import Renderer
from ..utils.types import HFModel, Message, Sample, TorchDataset
class SyncSampler(BaseSampler):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()
super().__init__(args, model_args, model, renderer)
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
"""Generate tokens synchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
generator = super().generate(messages, tools)
while True:
try:
token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
yield token
except StopAsyncIteration:
break
def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples synchronously.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()
def run_chat(args: InputArgument = None):
@@ -23,12 +77,48 @@ def run_chat(args: InputArgument = None):
if sample_args.sample_backend != SampleBackend.HF:
model_args.init_plugin = {"name": "init_on_meta"}
model_loader = ModelLoader(model_args)
sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
if data_args.dataset is not None:
sampler.batch_infer()
dataset = DataEngine(data_args)
sampler.batch_infer(dataset)
else:
sampler.generate()
if os.name != "nt":
try:
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")
messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise
if query.strip() == "exit":
break
if query.strip() == "clear":
messages = []
print("History has been removed.")
continue
messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
print("Assistant: ", end="", flush=True)
response = ""
for new_text in sampler.generate(messages):
print(new_text, end="", flush=True)
response += new_text
print()
messages.append(model_engine.renderer.parse_message(response))
if __name__ == "__main__":

View File

@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_loader import ModelLoader
from ..core.model_engine import ModelEngine
class SFTTrainer(BaseTrainer):
@@ -28,11 +28,11 @@ def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_loader = ModelLoader(model_args)
model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_loader.model,
processor=model_loader.processor,
model=model_engine.model,
processor=model_engine.processor,
dataset=data_engine,
)
trainer.fit()

View File

@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
IGNORE_INDEX = -100

View File

@@ -0,0 +1,29 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import PreTrainedTokenizer
from .types import Processor
def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
"""Get tokenizer from processor.
Args:
processor: Processor.
Returns:
Tokenizer.
"""
return processor.tokenizer if hasattr(processor, "tokenizer") else processor

View File

@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":
def _get_library_name() -> str:
return __name__.split(".")[0]
return ".".join(__name__.split(".")[:2]) # llamafactory.v1
def _get_library_root_logger() -> "_Logger":

View File

@@ -13,6 +13,7 @@
# limitations under the License.
from collections import defaultdict
from collections.abc import Callable
from . import logging
@@ -27,7 +28,7 @@ class BasePlugin:
A plugin is a callable object that can be registered and called by name.
"""
_registry: dict[str, Callable] = {}
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)
def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.
@@ -37,8 +38,7 @@ class BasePlugin:
"""
self.name = name
@property
def register(self):
def register(self, method_name: str = "__call__"):
"""Decorator to register a function as a plugin.
Example usage:
@@ -46,16 +46,21 @@ class BasePlugin:
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
```
"""
if self.name is None:
raise ValueError("Plugin name is not specified.")
raise ValueError("Plugin name should be specified.")
if self.name in self._registry:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
if method_name in self._registry[self.name]:
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")
def decorator(func: Callable) -> Callable:
self._registry[self.name] = func
self._registry[self.name][method_name] = func
return func
return decorator
@@ -68,10 +73,23 @@ class BasePlugin:
PrintPlugin("hello")()
```
"""
if self.name not in self._registry:
raise ValueError(f"Plugin {self.name} is not registered.")
if "__call__" not in self._registry[self.name]:
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")
return self._registry[self.name](*args, **kwargs)
return self._registry[self.name]["__call__"](*args, **kwargs)
def __getattr__(self, method_name: str):
"""Get the registered function with the given name.
Example usage:
```python
PrintPlugin("hello").again()
```
"""
if method_name not in self._registry[self.name]:
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")
return self._registry[self.name][method_name]
if __name__ == "__main__":
@@ -82,8 +100,13 @@ if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass
@PrintPlugin("hello").register
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")
@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
PrintPlugin("hello")()
PrintPlugin("hello").again()

View File

@@ -84,27 +84,59 @@ class DistributedConfig(TypedDict, total=False):
class Content(TypedDict):
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
type: Literal["text", "reasoning", "tool_call", "image_url"]
"""Type of the content."""
value: str
"""Value of the content."""
class Message(TypedDict):
role: Literal["system", "user", "assistant", "tool"]
"""Role of the message."""
content: list[Content]
loss_weight: float
"""Content of the message."""
loss_weight: NotRequired[float]
"""Loss weight for this message, default to 1.0. Required in training."""
class SFTSample(TypedDict):
messages: list[Message]
"""Messages in the sample."""
extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""
class DPOSample(TypedDict):
chosen_messages: list[Message]
"""Chosen messages in the sample."""
rejected_messages: list[Message]
"""Rejected messages in the sample."""
extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""
Sample = Union[SFTSample, DPOSample]
class ToolCall(TypedDict):
name: str
"""Function name."""
arguments: str
"""Function arguments."""
class ModelInput(TypedDict, total=False):
input_ids: list[int]
"""Input ids for the model."""
attention_mask: list[int]
"""Attention mask for the model."""
labels: list[int]
"""Labels for the model."""
loss_weights: list[float]
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""

View File

@@ -18,7 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""
import os
from typing import Optional
import sys
import pytest
import torch
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)
def _get_visible_devices_env() -> Optional[str]:
def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -149,6 +149,14 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required))
monkeypatch.setenv(env_key, devices_str)
# add project root dir to path for mp run
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)
os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")
else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.4.105
0.9.5.101

View File

@@ -1,173 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
import torch
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from llamafactory.v1.config.data_args import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.trainer_utils.data_collator import (
DefaultCollator,
)
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
from llamafactory.v1.utils.batching_queue import TextBatchingQueue
class TensorDataset(Dataset):
"""Wrapper dataset that converts DataEngine samples to tensor format."""
def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
self.data_engine = data_engine
self.processor = processor
self.template = template
self.max_samples = max_samples or len(data_engine)
self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
def __len__(self):
return min(self.max_samples, len(self.data_engine))
def __getitem__(self, idx):
# Get sample from DataEngine
sample = self.data_engine[idx]
# Extract messages from sample
# DataEngine returns samples with format like {"messages": [...], ...}
# For llamafactory/v1-sft-demo, the format should have "messages" field
messages = None
if "messages" in sample:
messages = sample["messages"]
elif "conversations" in sample:
messages = sample["conversations"]
elif "conversation" in sample:
messages = sample["conversation"]
else:
# Try to find message-like fields (skip _dataset_name)
for key, value in sample.items():
if key.startswith("_"):
continue
if isinstance(value, list) and len(value) > 0:
# Check if it looks like a message list
if isinstance(value[0], dict) and "role" in value[0]:
messages = value
break
if messages is None:
raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# Encode messages using template
encoded = self.template.encode_messages(self.tokenizer, messages)
# Convert to tensors
return {
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
"labels": torch.tensor(encoded["labels"], dtype=torch.long),
}
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
"""Create a real dataset using DataEngine."""
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
# Create processor and template
processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
template = QwenTemplate()
# Create tensor dataset
raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# Create torch DataLoader
torch_dataloader = TorchDataLoader(
raw_data_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=lambda x: x,
)
return torch_dataloader, processor, template
class TestDataLoaderNonPackNonDynamic:
"""Test case a) non pack + non dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing and without dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# Create collator (non-packing)
collator = DefaultCollator(processor=processor, template=template)
# Create DataLoader without batching_queue (non-dynamic)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=1,
batching_queue=None,
)
# Iterate and check results
batches = list(iter(data_loader))
assert len(batches) > 0
# Check first batch
one_batch = batches[0]
micro_batches = one_batch[0]
assert "input_ids" in micro_batches
assert "attention_mask" in micro_batches
assert "labels" in micro_batches
assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
class TestDataLoaderNonPackDynamic:
"""Test case b) non pack + dynamic."""
def test_basic_functionality(self):
"""Test DataLoader without packing but with dynamic batching."""
# Create real dataset
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
collator = DefaultCollator(processor=processor, template=template)
# Create batching queue for dynamic batching
batching_queue = TextBatchingQueue(
token_micro_bsz=120,
buffer_size=8,
)
data_loader = DataLoader(
dataloader=torch_dataloader,
collate_fn=collator,
num_micro_batch=4,
batching_queue=batching_queue,
)
# Iterate and check
batches = list(iter(data_loader))
micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
assert len(batches) > 0

View File

@@ -15,18 +15,18 @@
import torch
from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader
from llamafactory.v1.core.model_engine import ModelEngine
def test_tiny_qwen():
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast
model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
model_loader = ModelLoader(model_args)
assert isinstance(model_loader.processor, Qwen2TokenizerFast)
assert isinstance(model_loader.model.config, Qwen2Config)
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert model_loader.model.dtype == torch.bfloat16
model_engine = ModelEngine(model_args)
assert isinstance(model_engine.processor, Qwen2TokenizerFast)
assert isinstance(model_engine.model_config, Qwen2Config)
assert isinstance(model_engine.model, Qwen2ForCausalLM)
assert model_engine.model.dtype == torch.bfloat16
def test_tiny_qwen_with_kernel_plugin():
@@ -37,13 +37,14 @@ def test_tiny_qwen_with_kernel_plugin():
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
model_loader = ModelLoader(model_args)
model_engine = ModelEngine(model_args)
# test enable apply kernel plugin
if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_engine.model, Qwen2ForCausalLM)
if __name__ == "__main__":

View File

@@ -0,0 +1,171 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
Tests the 4 scenarios:
a) non pack + non dynamic.
b) non pack + dynamic.
c) pack + non dynamic.
d) pack + dynamic.
"""
# import torch
# from torch.utils.data import DataLoader as TorchDataLoader
# from torch.utils.data import Dataset
# from transformers import AutoTokenizer
# from llamafactory.v1.config.data_args import DataArguments
# from llamafactory.v1.core.data_engine import DataEngine
# from llamafactory.v1.core.utils.data_collator import DefaultCollator
# from llamafactory.v1.core.utils.data_loader import DataLoader
# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
# class TensorDataset(Dataset):
# """Wrapper dataset that converts DataEngine samples to tensor format."""
# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
# self.data_engine = data_engine
# self.processor = processor
# self.template = template
# self.max_samples = max_samples or len(data_engine)
# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
# def __len__(self):
# return min(self.max_samples, len(self.data_engine))
# def __getitem__(self, idx):
# # Get sample from DataEngine
# sample = self.data_engine[idx]
# # Extract messages from sample
# # DataEngine returns samples with format like {"messages": [...], ...}
# # For llamafactory/v1-sft-demo, the format should have "messages" field
# messages = None
# if "messages" in sample:
# messages = sample["messages"]
# elif "conversations" in sample:
# messages = sample["conversations"]
# elif "conversation" in sample:
# messages = sample["conversation"]
# else:
# # Try to find message-like fields (skip _dataset_name)
# for key, value in sample.items():
# if key.startswith("_"):
# continue
# if isinstance(value, list) and len(value) > 0:
# # Check if it looks like a message list
# if isinstance(value[0], dict) and "role" in value[0]:
# messages = value
# break
# if messages is None:
# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
# # Encode messages using template
# encoded = self.template.encode_messages(self.tokenizer, messages)
# # Convert to tensors
# return {
# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
# "labels": torch.tensor(encoded["labels"], dtype=torch.long),
# }
# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
# """Create a real dataset using DataEngine."""
# data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
# data_engine = DataEngine(data_args)
# # Create processor and template
# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
# template = QwenTemplate()
# # Create tensor dataset
# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
# # Create torch DataLoader
# torch_dataloader = TorchDataLoader(
# raw_data_dataset,
# batch_size=batch_size,
# shuffle=False,
# collate_fn=lambda x: x,
# )
# return torch_dataloader, processor, template
# class TestDataLoaderNonPackNonDynamic:
# """Test case a) non pack + non dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing and without dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# # Create collator (non-packing)
# collator = DefaultCollator(processor=processor, template=template)
# # Create DataLoader without batching_queue (non-dynamic)
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=1,
# batching_queue=None,
# )
# # Iterate and check results
# batches = list(iter(data_loader))
# assert len(batches) > 0
# # Check first batch
# one_batch = batches[0]
# micro_batches = one_batch[0]
# assert "input_ids" in micro_batches
# assert "attention_mask" in micro_batches
# assert "labels" in micro_batches
# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
# class TestDataLoaderNonPackDynamic:
# """Test case b) non pack + dynamic."""
# def test_basic_functionality(self):
# """Test DataLoader without packing but with dynamic batching."""
# # Create real dataset
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
# collator = DefaultCollator(processor=processor, template=template)
# # Create batching queue for dynamic batching
# batching_queue = TextBatchingQueue(
# token_micro_bsz=120,
# buffer_size=8,
# )
# data_loader = DataLoader(
# dataloader=torch_dataloader,
# collate_fn=collator,
# num_micro_batch=4,
# batching_queue=batching_queue,
# )
# # Iterate and check
# batches = list(iter(data_loader))
# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
# assert len(batches) > 0

View File

@@ -0,0 +1,65 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import AutoTokenizer
from llamafactory.v1.core.utils.rendering import Renderer
from llamafactory.v1.utils.types import Processor
HF_MESSAGES = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is LLM?"},
{"role": "assistant", "content": "LLM stands for Large Language Model."},
]
V1_MESSAGES = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
]
def test_chatml_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)
def test_chatml_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
generated_text = "LLM stands for Large Language Model."
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == V1_MESSAGES[-1]
if __name__ == "__main__":
test_chatml_rendering()
test_chatml_parse()

View File

@@ -54,7 +54,7 @@ def test_sharegpt_converter():
"conversations": [
{"from": "system", "value": "System"},
{"from": "human", "value": "User"},
{"from": "function_call", "value": "Tool"},
{"from": "function_call", "value": "1"},
{"from": "observation", "value": "Observation"},
{"from": "gpt", "value": "Assistant"},
]
@@ -63,7 +63,7 @@ def test_sharegpt_converter():
"messages": [
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
{"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
]

View File

@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader
from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
@@ -26,11 +25,10 @@ def test_init_on_meta():
init_config={"name": "init_on_meta"},
)
)
model_loader = ModelLoader(model_args=model_args)
assert model_loader.model.device.type == "meta"
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device.type == "meta"
@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
_, model_args, *_ = get_args(
dict(
@@ -38,11 +36,11 @@ def test_init_on_rank0():
init_config={"name": "init_on_rank0"},
)
)
model_loader = ModelLoader(model_args=model_args)
model_engine = ModelEngine(model_args=model_args)
if DistributedInterface().get_rank() == 0:
assert model_loader.model.device.type == "cpu"
assert model_engine.model.device.type == "cpu"
else:
assert model_loader.model.device.type == "meta"
assert model_engine.model.device.type == "meta"
def test_init_on_default():
@@ -52,5 +50,5 @@ def test_init_on_default():
init_config={"name": "init_on_default"},
)
)
model_loader = ModelLoader(model_args=model_args)
assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device

View File

@@ -0,0 +1,41 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.config import ModelArguments, SampleArguments
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.samplers.cli_sampler import SyncSampler
@pytest.mark.runs_on(["cuda", "npu"])
def test_sync_sampler():
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507")
sample_args = SampleArguments()
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}]
response = ""
for new_text in sampler.generate(messages):
response += new_text
print(response)
assert model_engine.renderer.parse_message(response) == {
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}
if __name__ == "__main__":
test_sync_sampler()