mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-01-13 09:30:34 +08:00)

Commit: [v1] add cli sampler (#9721)

.github/workflows/tests_cuda.yml (vendored) | 4
@@ -55,12 +55,12 @@ jobs:
          uv pip install -e .
          uv pip install -r requirements/dev.txt

      - name: Cache HuggingFace models
      - name: Cache files
        id: hf-hub-cache
        uses: actions/cache@v4
        with:
          path: ${{ runner.temp }}/huggingface
          key: hf-cache-${{ runner.os }}-${{ hashFiles('tests/version.txt') }}
          key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}

      - name: Check quality
        run: |
@@ -73,7 +73,7 @@ dependencies = [
    # api
    "uvicorn",
    "fastapi",
    "sse-starlette"
    "sse-starlette",
]

[project.scripts]
@@ -119,9 +119,19 @@ def synchronize() -> None:


@requires_accelerator
def set_device() -> None:
    """Set current accelerator."""
    torch.accelerator.set_device_index(get_local_rank())
def set_device_index() -> None:
    """Set current accelerator index to local rank."""
    if get_current_accelerator().type != DeviceType.CPU:
        torch.accelerator.set_device_index(get_local_rank())


@requires_accelerator
def get_current_device() -> torch.device:
    """Get current accelerator device."""
    if get_current_accelerator().type == DeviceType.CPU:
        return torch.device(DeviceType.CPU.value)
    else:
        return torch.device(type=get_current_accelerator().type, index=torch.accelerator.current_device_index())


def is_torch_cuda_available():

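For context, a minimal sketch (not part of this diff) of the torch.accelerator API these helpers wrap; it assumes a PyTorch build new enough (roughly 2.6+) to ship that module:

# illustrative only: query and pin the current accelerator, then build a device
import torch

if torch.accelerator.is_available():
    torch.accelerator.set_device_index(0)
    acc = torch.accelerator.current_accelerator()  # e.g. device(type="cuda")
    dev = torch.device(type=acc.type, index=torch.accelerator.current_device_index())
    print(dev)  # cuda:0 on a CUDA machine
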
@@ -123,12 +123,13 @@ class DistributedInterface:
        if self._initialized:
            return

        helper.set_device_index()
        self._is_distributed = helper.is_distributed()
        self._rank = helper.get_rank()
        self._world_size = helper.get_world_size()
        self._local_rank = helper.get_local_rank()
        self._local_world_size = helper.get_local_world_size()
        self.current_accelerator = helper.get_current_accelerator()
        self.current_device = helper.get_current_device()
        self.device_count = helper.get_device_count()

        if config is None:
@@ -144,15 +145,14 @@ class DistributedInterface:
        timeout = config.get("timeout", 18000)

        if self._is_distributed:
            helper.set_device()
            init_process_group(timeout=timedelta(seconds=timeout))
            self.model_device_mesh = init_device_mesh(
                device_type=self.current_accelerator.type,
                device_type=self.current_device.type,
                mesh_shape=self.strategy.model_mesh_shape,
                mesh_dim_names=self.strategy.model_mesh_dim_names,
            )
            self.data_device_mesh = init_device_mesh(
                device_type=self.current_accelerator.type,
                device_type=self.current_device.type,
                mesh_shape=self.strategy.data_mesh_shape,
                mesh_dim_names=self.strategy.data_mesh_dim_names,
            )
@@ -161,12 +161,12 @@ class DistributedInterface:
            self.data_device_mesh = None

        self._initialized = True
        logger.info_rank0(f"DistributedInterface initialized with strategy={self.strategy}.")
        logger.info_rank0(f"DistributedInterface initialized: {self}.")

    def __str__(self) -> str:
        return (
            f"DistributedInterface(strategy={self.strategy}), is_distributed={self._is_distributed}, "
            f"current_accelerator={self.current_accelerator}, rank={self._rank}, world_size={self._world_size}, "
            f"current_device={self.current_device}, rank={self._rank}, world_size={self._world_size}, "
            f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
        )

@@ -251,4 +251,7 @@ class DistributedInterface:


if __name__ == "__main__":
    print(DistributedInterface(DistributedStrategy()))
    """
    python -m llamafactory.v1.accelerator.interface
    """
    print(DistributedInterface())
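As an aside, a minimal standalone sketch (not from this PR) of the torch init_device_mesh call that the interface wraps; it assumes it runs under torchrun with four ranks:

# illustrative only: a 2x2 mesh with named data-parallel / tensor-parallel dims
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", mesh_shape=(2, 2), mesh_dim_names=("dp", "tp"))
print(mesh["dp"].size(), mesh["tp"].size())  # 2 2
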
@@ -17,7 +17,7 @@


import json
from enum import Enum, unique
from enum import StrEnum, unique


class PluginConfig(dict):
@@ -36,7 +36,7 @@ PluginArgument = PluginConfig | dict | str | None


@unique
class ModelClass(str, Enum):
class ModelClass(StrEnum):
    """Auto class for model config."""

    LLM = "llm"
@@ -45,7 +45,7 @@ class ModelClass(str, Enum):


@unique
class SampleBackend(str, Enum):
class SampleBackend(StrEnum):
    HF = "hf"
    VLLM = "vllm"

@@ -21,8 +21,13 @@ from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@dataclass
class ModelArguments:
    model: str = field(
        default="Qwen/Qwen3-4B-Instruct-2507",
        metadata={"help": "Path to the model or model identifier from Hugging Face."},
    )
    template: str = field(
        default="chatml",
        metadata={"help": "Template for the model."},
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Trust remote code from Hugging Face."},
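For illustration (values below are simply the defaults shown above, not a recommendation), the dataclass can be instantiated directly:

model_args = ModelArguments(
    model="Qwen/Qwen3-4B-Instruct-2507",
    template="chatml",
    trust_remote_code=False,
)
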
@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from threading import Thread

import torch
from transformers import TextIteratorStreamer

from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Processor, TorchDataset
from ..utils.helper import get_tokenizer
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.rendering import Renderer


class BaseEngine(ABC):
@@ -24,8 +34,8 @@ class BaseEngine(ABC):
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel = None,
        processor: Processor = None,
        model: HFModel,
        renderer: Renderer,
    ) -> None:
        """Initialize the engine.

@@ -33,17 +43,34 @@ class BaseEngine(ABC):
            args: Sample arguments.
            model_args: Model arguments.
            model: Model.
            processor: Processor.
            renderer: Renderer.
        """
        ...

    @abstractmethod
    async def generate(self, messages):
        pass
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        ...

    @abstractmethod
    async def batch_infer(self, data: TorchDataset) -> None:
        pass
    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        ...


class HuggingFaceEngine(BaseEngine):
@@ -52,26 +79,103 @@ class HuggingFaceEngine(BaseEngine):
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        processor: Processor,
        renderer: Renderer,
    ) -> None:
        self.args = args
        self.model_args = model_args
        self.model = model
        self.renderer = renderer
        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @torch.inference_mode()
    def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
        model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
        streamer = TextIteratorStreamer(
            tokenizer=get_tokenizer(self.renderer.processor),
            skip_prompt=True,
            skip_special_tokens=True, # TODO: configurable
        )
        device = DistributedInterface().current_device
        kwargs = {
            "input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
            "attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
            "max_new_tokens": self.args.max_new_tokens,
            "streamer": streamer,
        }
        thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        async with self.semaphore:
            response = self.get_response(messages, tools)
            while True:
                try:
                    yield await asyncio.to_thread(response)
                except StopAsyncIteration:
                    break

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        raise NotImplementedError("Batch infer is not implemented.")


class BaseSampler:
    """Base sampler.

    Args:
        args: Sample arguments.
        model_args: Model arguments.
        model: Model.
        renderer: Renderer.
    """

    def __init__(
        self,
        args: SampleArguments,
        model_args: ModelArguments,
        model: HFModel,
        processor: Processor,
        renderer: Renderer,
    ) -> None:
        if args.sample_backend == SampleBackend.HF:
            self.engine = HuggingFaceEngine(args, model_args, model, processor)
            self.engine = HuggingFaceEngine(args, model_args, model, renderer)
        else:
            raise ValueError(f"Unknown sample backend: {args.sample_backend}")

    async def generate(self, messages):
        return await self.engine.generate(messages)
    async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
        """Generate tokens asynchronously.

    async def batch_infer(self, data: TorchDataset) -> None:
        return await self.engine.batch_infer(data)
        Args:
            messages: List of messages.
            tools: Tools string.

        Yields:
            Generated tokens.
        """
        async for token in self.engine.generate(messages, tools):
            yield token

    async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
        """Batch infer samples.

        Args:
            dataset: Torch dataset.

        Returns:
            List of samples.
        """
        return await self.engine.batch_infer(dataset)
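As a usage note (hypothetical driver, not part of the diff), the async generate() can be consumed directly with asyncio once a sampler has been constructed:

# sketch only: `sampler` is assumed to be a BaseSampler built as in run_chat below
import asyncio

async def chat_once(sampler, user_text: str) -> str:
    messages = [{"role": "user", "content": [{"type": "text", "value": user_text}]}]
    response = ""
    async for token in sampler.generate(messages):
        response += token
    return response

# asyncio.run(chat_once(sampler, "Hello"))
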
@@ -14,15 +14,23 @@

"""The definition of data engine.

Init Data engine:
How to use:
    data_engine = DataEngine(data_args)
    data_engine[i]: Get the sample via index.

Init workflow:
    1. Parse dataset info from arguments.
    2. Load datasets according to dataset info.
    3. Build data index (and reweight samples if necessary).

Get Data Sample:
Get data sample:
    1. Get sample from data index.
    2. Convert sample to standard format.
    3. Return sample.

Note:
    1. The data engine is equivalent to the torch dataset.
    2. The data engine is agnostic to the model used.
"""

import os
@@ -98,10 +106,10 @@ class DataEngine(Dataset):

        size = self.dataset_infos[dataset_name].get("size")
        weight = self.dataset_infos[dataset_name].get("weight")
        if size or weight: # data index plugin
            from ..plugins.data_plugins.loader import DataIndexPlugin
        if size or weight:
            from ..plugins.data_plugins.loader import adjust_data_index

            data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
            data_index = adjust_data_index(data_index, size, weight)

        self.data_index.extend(data_index)

@@ -150,9 +158,9 @@ class DataEngine(Dataset):
            dataset_name, sample_index = self.data_index[index]
            return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
        else: # data selector plugin
            from ..plugins.data_plugins.loader import DataSelectorPlugin
            from ..plugins.data_plugins.loader import select_data_sample

            selected_index = DataSelectorPlugin().select(self.data_index, index)
            selected_index = select_data_sample(self.data_index, index)
            if isinstance(selected_index, list):
                return [
                    self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
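To make the "equivalent to the torch dataset" note concrete, a small illustrative sketch (argument values are placeholders): the engine can be indexed directly or handed to a DataLoader.

# sketch only: data_args is assumed to come from get_args()
from torch.utils.data import DataLoader

data_engine = DataEngine(data_args)
first_sample = data_engine[0]  # sample in the standard message format
loader = DataLoader(data_engine, batch_size=8, collate_fn=list)
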
@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""The definition of model loader.
"""The definition of model engine.

How to use:
    model_loader = ModelLoader(model_args, is_trainable=True)
    model_loader.processor: Get the tokenizer or multi-modal processor.
    model_loader.model_config: Get the model configuration.
    model_loader.model: Get the HF model.
    model_engine = ModelEngine(model_args, is_train=True)
    model_engine.processor: Get the tokenizer or multi-modal processor.
    model_engine.renderer: Get the renderer.
    model_engine.model_config: Get the model configuration.
    model_engine.model: Get the HF model.

Init Workflow:
Init workflow:
    1. Init processor.
    2. Init render.
    2. Init model config.
    3. Init model.
    4. Init adapter.
@@ -36,17 +38,18 @@ from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer


logger = logging.get_logger(__name__)


class ModelLoader:
    """Model loader.
class ModelEngine:
    """Model engine.

    Args:
        model_args: Model arguments.
        is_trainable: Whether to train the model.
        is_train: Whether to train the model.
    """

    def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -56,6 +59,8 @@ class ModelLoader:
        """Whether to train the model."""
        self.processor = self._init_processor()
        """Tokenizer or multi-modal processor."""
        self.renderer = Renderer(self.args.template, self.processor)
        """Renderer."""
        self.model_config = self._init_model_config()
        """Model configuration."""
        self.model = self._init_model()
@@ -107,7 +112,7 @@ class ModelLoader:

            init_device = InitPlugin(self.args.init_config.name)()
        else:
            init_device = DistributedInterface().current_accelerator
            init_device = DistributedInterface().current_device

        if init_device.type == DeviceType.META:
            with init_empty_weights():
@@ -144,12 +149,12 @@ class ModelLoader:

if __name__ == "__main__":
    """
    python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
    python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
    """
    from ..config.arg_parser import get_args

    _, model_args, *_ = get_args()
    model_loader = ModelLoader(model_args=model_args)
    print(model_loader.processor)
    print(model_loader.model_config)
    print(model_loader.model)
    model_engine = ModelEngine(model_args=model_args)
    print(model_engine.processor)
    print(model_engine.model_config)
    print(model_engine.model)
src/llamafactory/v1/core/utils/rendering.py (new file, 239 lines)
@@ -0,0 +1,239 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
|
||||
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def render_chatml_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
"""Apply chatml template to messages and convert them to model input.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
|
||||
"""
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n"
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
def parse_chatml_message(generated_text: str) -> Message:
|
||||
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
|
||||
|
||||
Args:
|
||||
generated_text (str): The generated text in ChatML format.
|
||||
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
|
||||
|
||||
class Renderer:
|
||||
def __init__(self, template: str, processor: Processor):
|
||||
self.template = template
|
||||
self.processor = processor
|
||||
|
||||
def render_messages(
|
||||
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
|
||||
) -> ModelInput:
|
||||
if self.template == "chatml":
|
||||
return render_chatml_messages(self.processor, messages, tools, is_generate)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
|
||||
|
||||
def parse_message(self, generated_text: str) -> Message:
|
||||
if self.template == "chatml":
|
||||
return parse_chatml_message(generated_text)
|
||||
else:
|
||||
from ...plugins.model_plugins.rendering import RenderingPlugin
|
||||
|
||||
return RenderingPlugin(self.template).parse_message(generated_text)
|
||||
@@ -49,6 +49,11 @@ def launch():

        run_sft()

    elif command == "chat":
        from .samplers.cli_sampler import run_chat

        run_chat()

    elif command == "env":
        print_env()

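For reference, a hypothetical direct invocation of the new entry point (the absolute module path is inferred from the relative import above; the CLI command name itself comes from [project.scripts]):

from llamafactory.v1.samplers.cli_sampler import run_chat

# passing None is assumed to fall back to the usual get_args() argument parsing
run_chat(None)
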
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
from typing import Any, Literal, NotRequired, TypedDict
|
||||
|
||||
from ...utils import logging
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import DPOSample, Sample, SFTSample
|
||||
from ...utils.types import DPOSample, Sample, SFTSample, ToolCall
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -61,7 +62,7 @@ class DataConverterPlugin(BasePlugin):
|
||||
return super().__call__(raw_sample)
|
||||
|
||||
|
||||
@DataConverterPlugin("alpaca").register
|
||||
@DataConverterPlugin("alpaca").register()
|
||||
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
"""Convert Alpaca sample to SFT sample.
|
||||
|
||||
@@ -98,7 +99,7 @@ def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("sharegpt").register
|
||||
@DataConverterPlugin("sharegpt").register()
|
||||
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"""Convert ShareGPT sample to SFT sample.
|
||||
|
||||
@@ -118,17 +119,32 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
"function_call": "assistant",
|
||||
}
|
||||
messages = []
|
||||
tools = raw_sample.get("tools", "")
|
||||
tools = raw_sample.get("tools")
|
||||
if tools:
|
||||
try:
|
||||
tools: list[dict[str, Any]] = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tools format: {str(tools)}")
|
||||
tools = []
|
||||
|
||||
for message in raw_sample.get("conversations", []):
|
||||
tag = message["from"]
|
||||
if tag not in tag_mapping:
|
||||
logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
|
||||
elif tag == "function_call":
|
||||
try:
|
||||
tool_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
|
||||
except json.JSONDecodeError:
|
||||
logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
|
||||
continue
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
tool_calls = [tool_calls]
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_calls", "value": message["value"]}],
|
||||
"content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
|
||||
"loss_weight": 1.0,
|
||||
}
|
||||
)
|
||||
@@ -142,15 +158,12 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
)
|
||||
|
||||
if tools:
|
||||
if messages and messages[0]["role"] == "system":
|
||||
messages[0]["content"].append({"type": "tools", "value": tools})
|
||||
else:
|
||||
messages.insert(0, {"role": "system", "content": [{"type": "tools", "value": tools}], "loss_weight": 0.0})
|
||||
|
||||
return {"messages": messages}
|
||||
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
|
||||
else:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@DataConverterPlugin("pair").register
|
||||
@DataConverterPlugin("pair").register()
|
||||
def pair_converter(raw_sample: PairSample) -> DPOSample:
|
||||
"""Convert Pair sample to DPO sample.
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ def _get_builder_name(path: str) -> Literal["arrow", "csv", "json", "parquet", "
|
||||
raise ValueError(f"Unknown dataset filetype: {filetype}.")
|
||||
|
||||
|
||||
@DataLoaderPlugin("local").register
|
||||
@DataLoaderPlugin("local").register()
|
||||
def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset:
|
||||
if os.path.isdir(filepath):
|
||||
filetype = _get_builder_name(os.listdir(filepath)[0])
|
||||
@@ -66,49 +66,43 @@ def load_data_from_file(filepath: str, split: str, streaming: bool) -> HFDataset
|
||||
return dataset
|
||||
|
||||
|
||||
class DataIndexPlugin(BasePlugin):
|
||||
"""Plugin for adjusting dataset index."""
|
||||
def adjust_data_index(
|
||||
data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Adjust dataset index by size and weight.
|
||||
|
||||
def adjust_data_index(
|
||||
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Adjust dataset index by size and weight.
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
size (Optional[int]): Desired dataset size.
|
||||
weight (Optional[float]): Desired dataset weight.
|
||||
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
size (Optional[int]): Desired dataset size.
|
||||
weight (Optional[float]): Desired dataset weight.
|
||||
Returns:
|
||||
list[tuple[str, int]]: Adjusted dataset index.
|
||||
"""
|
||||
if size is not None:
|
||||
data_index = random.choices(data_index, k=size)
|
||||
|
||||
Returns:
|
||||
list[tuple[str, int]]: Adjusted dataset index.
|
||||
"""
|
||||
if size is not None:
|
||||
data_index = random.choices(data_index, k=size)
|
||||
if weight is not None:
|
||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||
|
||||
if weight is not None:
|
||||
data_index = random.choices(data_index, k=int(len(data_index) * weight))
|
||||
|
||||
return data_index
|
||||
return data_index
|
||||
|
||||
|
||||
class DataSelectorPlugin(BasePlugin):
|
||||
"""Plugin for selecting dataset samples."""
|
||||
def select_data_sample(
|
||||
data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||
"""Select dataset samples.
|
||||
|
||||
def select(
|
||||
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
|
||||
) -> tuple[str, int] | list[tuple[str, int]]:
|
||||
"""Select dataset samples.
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||
|
||||
Args:
|
||||
data_index (list[tuple[str, int]]): List of (dataset_name, sample_index).
|
||||
index (Union[slice, list[int], Any]): Index of dataset samples.
|
||||
|
||||
Returns:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||
"""
|
||||
if isinstance(index, slice):
|
||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||
elif isinstance(index, list):
|
||||
return [data_index[i] for i in index]
|
||||
else:
|
||||
raise ValueError(f"Invalid index type {type(index)}.")
|
||||
Returns:
|
||||
Union[tuple[str, int], list[tuple[str, int]]]: Selected dataset samples.
|
||||
"""
|
||||
if isinstance(index, slice):
|
||||
return [data_index[i] for i in range(*index.indices(len(data_index)))]
|
||||
elif isinstance(index, list):
|
||||
return [data_index[i] for i in index]
|
||||
else:
|
||||
raise ValueError(f"Invalid index type {type(index)}.")
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
user_template: str
|
||||
assistant_template: str
|
||||
system_template: str
|
||||
|
||||
def render_message(self, message: dict[str, str]) -> str:
|
||||
return self.user_template.format(**message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QwenTemplate:
|
||||
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
|
||||
thinking_template: str = "<think>\n{content}\n</think>\n\n"
|
||||
|
||||
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
|
||||
if isinstance(content_data, str):
|
||||
return content_data.strip()
|
||||
|
||||
if isinstance(content_data, list):
|
||||
parts = []
|
||||
for item in content_data:
|
||||
if item.get("type") == "text":
|
||||
parts.append(item.get("value", ""))
|
||||
elif item.get("type") == "image_url":
|
||||
pass
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
return ""
|
||||
|
||||
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
|
||||
role = message["role"]
|
||||
content = self._extract_content(message.get("content", ""))
|
||||
|
||||
if role == "assistant":
|
||||
reasoning_content = message.get("reasoning_content", "")
|
||||
if reasoning_content:
|
||||
reasoning_content = self.thinking_template.format(content=str(reasoning_content).strip())
|
||||
return self.message_template.format(role="assistant", content=reasoning_content + content)
|
||||
else:
|
||||
return self.message_template.format(role=role, content=content)
|
||||
|
||||
def encode_messages(self, tokenizer, messages: list[dict[str, str]], max_seq_len: int = 8192) -> any:
|
||||
"""Encode one message."""
|
||||
input_ids, attention_mask, labels = [], [], []
|
||||
for message in messages:
|
||||
content_str = self.render_message(message)
|
||||
content_ids = tokenizer.encode(content_str, add_special_tokens=False)
|
||||
input_ids += content_ids
|
||||
attention_mask += [1] * len(content_ids)
|
||||
|
||||
if hasattr(message, "loss_weight"):
|
||||
loss_weight = message["loss_weight"]
|
||||
else:
|
||||
loss_weight = 1 if message["role"] == "assistant" else 0
|
||||
if loss_weight == 1:
|
||||
labels += content_ids
|
||||
else:
|
||||
labels += [-100] * len(content_ids)
|
||||
model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
|
||||
model_inputs.update({"position_ids": list(range(len(input_ids)))})
|
||||
model_inputs = {k: v[-max_seq_len:] for k, v in model_inputs.items()}
|
||||
return model_inputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def to_qwen3_messages(template: QwenTemplate, messages: list[dict]):
|
||||
out = []
|
||||
for m in messages:
|
||||
role = m["role"]
|
||||
content = template._extract_content(m.get("content", ""))
|
||||
if role == "assistant":
|
||||
reasoning = (m.get("reasoning_content") or "").strip()
|
||||
if reasoning:
|
||||
content = template.thinking_template.format(content=reasoning) + content
|
||||
out.append({"role": role, "content": content})
|
||||
return out
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tok = AutoTokenizer.from_pretrained(
|
||||
"Qwen/Qwen3-30B-A3B-Thinking-2507",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
test_messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "1+1等于几?"}, {"type": "text", "text": "2+2等于几?"}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"reasoning_content": "这是一个简单的数学问题。1加1的结果是2。",
|
||||
"content": [{"type": "text", "text": "1+1=2"}, {"type": "text", "text": "2+2=4"}],
|
||||
},
|
||||
]
|
||||
|
||||
template = QwenTemplate()
|
||||
rendered_custom = "".join([template.render_message(m) for m in test_messages])
|
||||
|
||||
qwen3_messages = to_qwen3_messages(template, test_messages)
|
||||
rendered_hf = tok.apply_chat_template(qwen3_messages, tokenize=False, add_generation_prompt=False)
|
||||
|
||||
print("==== custom ====")
|
||||
print(rendered_custom)
|
||||
print("==== hf ====")
|
||||
print(rendered_hf)
|
||||
|
||||
assert rendered_custom.strip() == rendered_hf.strip(), "Rendered text mismatch"
|
||||
|
||||
ids_custom = tok.encode(rendered_custom, add_special_tokens=False)
|
||||
ids_hf = tok.apply_chat_template(qwen3_messages, tokenize=True, add_generation_prompt=False)
|
||||
assert ids_custom == ids_hf, f"Token ids mismatch: custom={len(ids_custom)} hf={len(ids_hf)}"
|
||||
@@ -25,12 +25,12 @@ class InitPlugin(BasePlugin):
|
||||
return super().__call__()
|
||||
|
||||
|
||||
@InitPlugin("init_on_meta").register
|
||||
@InitPlugin("init_on_meta").register()
|
||||
def init_on_meta() -> torch.device:
|
||||
return torch.device(DeviceType.META.value)
|
||||
|
||||
|
||||
@InitPlugin("init_on_rank0").register
|
||||
@InitPlugin("init_on_rank0").register()
|
||||
def init_on_rank0() -> torch.device:
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
@@ -38,6 +38,6 @@ def init_on_rank0() -> torch.device:
|
||||
return torch.device(DeviceType.META.value)
|
||||
|
||||
|
||||
@InitPlugin("init_on_default").register
|
||||
@InitPlugin("init_on_default").register()
|
||||
def init_on_default() -> torch.device:
|
||||
return DistributedInterface().current_accelerator
|
||||
return DistributedInterface().current_device
|
||||
|
||||
@@ -38,17 +38,17 @@ class BaseKernel(ABC):
|
||||
|
||||
@classmethod
|
||||
def get_kernel_id(cls) -> str:
|
||||
r"""Returns the unique identifier for the kernel."""
|
||||
"""Returns the unique identifier for the kernel."""
|
||||
return cls._kernel_id
|
||||
|
||||
@classmethod
|
||||
def get_device(cls) -> str:
|
||||
r"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||
"""Returns the device type associated with the kernel (e.g., "cuda", "npu", "cpu")."""
|
||||
return cls._device
|
||||
|
||||
@classmethod
|
||||
def check_deps(cls) -> bool:
|
||||
r"""Checks if the required dependencies for the kernel are available.
|
||||
"""Checks if the required dependencies for the kernel are available.
|
||||
|
||||
Returns:
|
||||
bool: ``True`` if dependencies are met, ``False`` otherwise.
|
||||
@@ -65,7 +65,7 @@ class BaseKernel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
r"""Applies the kernel optimization to the model.
|
||||
"""Applies the kernel optimization to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Arbitrary keyword arguments, usually containing the model instance and the kernel configuration.
|
||||
|
||||
@@ -33,7 +33,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def scan_all_kernels():
|
||||
r"""Scan all kernels in the ``ops`` directory.
|
||||
"""Scan all kernels in the ``ops`` directory.
|
||||
|
||||
Scans the ``ops`` directory for all ``.py`` files and attempts to import them.
|
||||
Importing triggers the :func:`~registry.register_kernel` decorator, which automatically registers the kernels.
|
||||
@@ -77,7 +77,7 @@ default_kernels = scan_all_kernels()
|
||||
|
||||
|
||||
def get_default_kernels():
|
||||
r"""Get a list of default registered kernel IDs.
|
||||
"""Get a list of default registered kernel IDs.
|
||||
|
||||
Returns:
|
||||
list[str]: List of kernel IDs.
|
||||
@@ -86,7 +86,7 @@ def get_default_kernels():
|
||||
|
||||
|
||||
def apply_kernel(kernel_id: str, **kwargs):
|
||||
r"""Applies a specific kernel to the model.
|
||||
"""Applies a specific kernel to the model.
|
||||
|
||||
Args:
|
||||
kernel_id (str): The ID of the kernel to apply.
|
||||
@@ -99,18 +99,19 @@ def apply_kernel(kernel_id: str, **kwargs):
|
||||
kernel = default_kernels.get(kernel_id)
|
||||
if kernel is None:
|
||||
raise ValueError(f"Kernel {kernel_id} not found")
|
||||
|
||||
kernel.apply(**kwargs)
|
||||
|
||||
|
||||
class KernelPlugin(BasePlugin):
|
||||
r"""Plugin for managing kernel optimizations."""
|
||||
"""Plugin for managing kernel optimizations."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@KernelPlugin("auto").register
|
||||
@KernelPlugin("auto").register()
|
||||
def apply_default_kernels(**kwargs):
|
||||
r"""Applies all default registered kernels to the model.
|
||||
"""Applies all default registered kernels to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to the kernel application function.
|
||||
@@ -125,8 +126,11 @@ def apply_default_kernels(**kwargs):
|
||||
use_kernels = default_kernels.keys()
|
||||
else:
|
||||
use_kernels = kwargs.get("include_kernels").split(",") # "kernel_id1,kernel_id2,kernel_id3"
|
||||
|
||||
for kernel in use_kernels:
|
||||
if kernel not in default_kernels:
|
||||
raise ValueError(f"Kernel {kernel} not found")
|
||||
|
||||
apply_kernel(kernel, **kwargs)
|
||||
|
||||
return kwargs.get("model")
|
||||
|
||||
@@ -40,11 +40,11 @@ from ...registry import register_kernel
|
||||
|
||||
|
||||
class GmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
||||
"""Custom autograd function for NPU Grouped Matrix Multiplication (GMM)."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x, weight, group_list):
|
||||
r"""Performs the forward pass of Grouped Matrix Multiplication.
|
||||
"""Performs the forward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors for backward pass.
|
||||
@@ -65,7 +65,7 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
r"""Performs the backward pass of Grouped Matrix Multiplication.
|
||||
"""Performs the backward pass of Grouped Matrix Multiplication.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
@@ -94,11 +94,11 @@ class GmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class HybridGmmFunction(torch.autograd.Function):
|
||||
r"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
||||
"""Custom autograd function for Hybrid Grouped Matrix Multiplication on NPU."""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, num_experts, *args):
|
||||
r"""Performs the forward pass of Hybrid GMM.
|
||||
"""Performs the forward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object to save tensors.
|
||||
@@ -124,7 +124,7 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *grad_outputs):
|
||||
r"""Performs the backward pass of Hybrid GMM.
|
||||
"""Performs the backward pass of Hybrid GMM.
|
||||
|
||||
Args:
|
||||
ctx: Context object containing saved tensors.
|
||||
@@ -176,13 +176,13 @@ class HybridGmmFunction(torch.autograd.Function):
|
||||
|
||||
|
||||
class NpuMoeFused:
|
||||
r"""Container for NPU fused MoE forward functions."""
|
||||
"""Container for NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def npu_moe_experts_forward(
|
||||
self, hidden_states: torch.Tensor, routing_weights: torch.Tensor, router_indices: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
r"""Forward pass for MoE experts using NPU fused operations.
|
||||
"""Forward pass for MoE experts using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The MoE layer instance.
|
||||
@@ -230,11 +230,11 @@ class NpuMoeFused:
|
||||
|
||||
|
||||
class Qwen3NpuMoeFused:
|
||||
r"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
"""Container for Qwen3 NPU fused MoE forward functions."""
|
||||
|
||||
@staticmethod
|
||||
def qwen3moe_sparse_moe_block_forward(self, hidden_states: torch.Tensor):
|
||||
r"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
||||
"""Forward pass for Qwen3 sparse MoE block using NPU fused operations.
|
||||
|
||||
Args:
|
||||
self: The Qwen3 MoE block instance.
|
||||
@@ -298,14 +298,14 @@ if not is_transformers_version_greater_than("5.0.0"):
|
||||
|
||||
@register_kernel
|
||||
class NpuFusedMoEKernel(BaseKernel):
|
||||
r"""NPU Fused MoE Kernel implementation."""
|
||||
"""NPU Fused MoE Kernel implementation."""
|
||||
|
||||
_kernel_id = "npu_fused_moe"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> HFModel:
|
||||
r"""Applies the NPU fused MoE kernel to the model.
|
||||
"""Applies the NPU fused MoE kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
@@ -333,6 +333,7 @@ class NpuFusedMoEKernel(BaseKernel):
|
||||
|
||||
if target_moe_mapping is None:
|
||||
return model
|
||||
|
||||
for module in model.modules():
|
||||
class_name = module.__class__.__name__
|
||||
if class_name in target_moe_mapping:
|
||||
|
||||
@@ -38,7 +38,7 @@ except ImportError:
|
||||
|
||||
|
||||
def npu_swiglu_forward(self, hidden_state):
|
||||
r"""SwiGLU forward pass for NPU.
|
||||
"""SwiGLU forward pass for NPU.
|
||||
|
||||
Args:
|
||||
self: The MLP layer instance.
|
||||
@@ -53,7 +53,7 @@ def npu_swiglu_forward(self, hidden_state):
|
||||
|
||||
|
||||
def _npu_swiglu_glm4_forward(self, hidden_states):
|
||||
r"""SwiGLU forward pass for GLM4 on NPU.
|
||||
"""SwiGLU forward pass for GLM4 on NPU.
|
||||
|
||||
Args:
|
||||
self: The GLM4 MLP layer instance.
|
||||
@@ -68,7 +68,7 @@ def _npu_swiglu_glm4_forward(self, hidden_states):
|
||||
|
||||
|
||||
def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
r"""SwiGLU forward pass for Gemma3nText on NPU.
|
||||
"""SwiGLU forward pass for Gemma3nText on NPU.
|
||||
|
||||
Args:
|
||||
self: The Gemma3nText MLP layer instance.
|
||||
@@ -88,7 +88,7 @@ def _npu_swiglu_gemma3ntext_forward(self, hidden_states):
|
||||
|
||||
@register_kernel
|
||||
class NpuSwiGluKernel(BaseKernel):
|
||||
r"""NPU Kernel for fused SwiGLU activation."""
|
||||
"""NPU Kernel for fused SwiGLU activation."""
|
||||
|
||||
# just support apply to the following module layers
|
||||
expect_modules = frozenset(
|
||||
@@ -126,7 +126,7 @@ class NpuSwiGluKernel(BaseKernel):
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Applies the NPU fused SwiGLU kernel to the model.
|
||||
"""Applies the NPU fused SwiGLU kernel to the model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
@@ -30,7 +30,7 @@ from ...registry import register_kernel
|
||||
|
||||
|
||||
def npu_rms_norm_forward(self, hidden_states):
|
||||
r"""NPU forward implementation for RMSNorm.
|
||||
"""NPU forward implementation for RMSNorm.
|
||||
|
||||
Args:
|
||||
self: RMSNorm module instance with `weight` and `variance_epsilon`.
|
||||
@@ -46,14 +46,14 @@ def npu_rms_norm_forward(self, hidden_states):
|
||||
|
||||
@register_kernel
|
||||
class NpuRMSNormKernel(BaseKernel):
|
||||
r"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
||||
"""NPU kernel wrapper for RMSNorm that applies the replacement within a model."""
|
||||
|
||||
_kernel_id = "npu_fused_rmsnorm"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||
"""Iterate the model and apply NPU-optimized forward to matched RMSNorm modules.
|
||||
|
||||
Key points:
|
||||
- Match modules whose class name contains "RMSNorm" (case-insensitive).
|
||||
@@ -78,6 +78,7 @@ class NpuRMSNormKernel(BaseKernel):
|
||||
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||
|
||||
rms_norm_pattern = re.compile("RMSNorm", re.IGNORECASE)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
|
||||
@@ -40,7 +40,7 @@ except ImportError:
|
||||
|
||||
|
||||
def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
r"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
||||
"""Applies Rotary Position Embedding to the query and key tensors using NPU optimization.
|
||||
|
||||
Args:
|
||||
q (Tensor): Query tensor.
|
||||
@@ -61,7 +61,7 @@ def _apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||
|
||||
|
||||
def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||
r"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||
"""Applies Rotary Position Embedding with multimodal sections (Qwen2-VL) on NPU.
|
||||
|
||||
Args:
|
||||
q (Tensor): Query tensor.
|
||||
@@ -89,14 +89,14 @@ def _apply_multimodal_rotary_pos_emb_qwen25_vl(q, k, cos, sin, mrope_section, un
|
||||
|
||||
@register_kernel
|
||||
class NpuRoPEKernel(BaseKernel):
|
||||
r"""NPU Kernel for Rotary Position Embedding."""
|
||||
"""NPU Kernel for Rotary Position Embedding."""
|
||||
|
||||
_kernel_id = "npu_fused_rope"
|
||||
_device = DeviceType.NPU
|
||||
|
||||
@classmethod
|
||||
def apply(cls, **kwargs) -> "HFModel":
|
||||
r"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
"""Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
|
||||
|
||||
This function iterates through the model's modules to find attention layers,
|
||||
identifies the module where they are defined, and replaces the original
|
||||
@@ -115,9 +115,11 @@ class NpuRoPEKernel(BaseKernel):
|
||||
"""
|
||||
if not cls.check_deps():
|
||||
raise RuntimeError(f"torch_npu is not available but {cls.__name__} was called.")
|
||||
|
||||
model = kwargs.get("model", None)
|
||||
if model is None:
|
||||
raise ValueError(f"HFModel instance is required for {cls.__name__}.")
|
||||
|
||||
_modules = set()
|
||||
for module in model.modules():
|
||||
if "Attention" in module.__class__.__name__:
|
||||
@@ -143,4 +145,5 @@ class NpuRoPEKernel(BaseKernel):
|
||||
_modules.add(module_name)
|
||||
except Exception as e:
|
||||
logger.warning_rank0_once(f"Failed to apply RoPE kernel to module {module_name}: {e}")
|
||||
|
||||
return model
|
||||
|
||||
@@ -30,7 +30,7 @@ __all__ = ["Registry", "register_kernel"]
|
||||
|
||||
|
||||
class Registry:
|
||||
r"""Registry for managing kernel implementations.
|
||||
"""Registry for managing kernel implementations.
|
||||
|
||||
Storage structure: ``{ "kernel_id": Class }``
|
||||
"""
|
||||
@@ -38,8 +38,8 @@ class Registry:
|
||||
_kernels: dict[str, type[BaseKernel]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, kernel_cls: type[BaseKernel]):
|
||||
r"""Decorator to register a kernel class.
|
||||
def register(cls, kernel_cls: type[BaseKernel]) -> type[BaseKernel] | None:
|
||||
"""Decorator to register a kernel class.
|
||||
|
||||
The class must inherit from :class:`BaseKernel` and specify ``_kernel_id`` and ``_device`` attributes.
|
||||
|
||||
@@ -47,7 +47,7 @@ class Registry:
|
||||
kernel_cls (type[BaseKernel]): The kernel class to register.
|
||||
|
||||
Returns:
|
||||
type[BaseKernel]: The registered kernel class.
|
||||
type[BaseKernel] | None: The registered kernel class if the device type matches the current accelerator
|
||||
|
||||
Raises:
|
||||
TypeError: If the class does not inherit from :class:`BaseKernel`.
|
||||
@@ -55,6 +55,7 @@ class Registry:
|
||||
"""
|
||||
if not issubclass(kernel_cls, BaseKernel):
|
||||
raise TypeError(f"Class {kernel_cls} must inherit from BaseKernel")
|
||||
|
||||
kernel_id = kernel_cls.get_kernel_id()
|
||||
device = kernel_cls.get_device()
|
||||
|
||||
@@ -73,7 +74,7 @@ class Registry:
|
||||
|
||||
@classmethod
|
||||
def get(cls, kernel_id: str) -> Optional[type[BaseKernel]]:
|
||||
r"""Retrieves a registered kernel implementation by its ID.
|
||||
"""Retrieves a registered kernel implementation by its ID.
|
||||
|
||||
Args:
|
||||
kernel_id (str): The ID of the kernel to retrieve.
|
||||
@@ -85,7 +86,7 @@ class Registry:
|
||||
|
||||
@classmethod
|
||||
def get_registered_kernels(cls) -> dict[str, type[BaseKernel]]:
|
||||
r"""Returns a dictionary of all registered kernels.
|
||||
"""Returns a dictionary of all registered kernels.
|
||||
|
||||
Returns:
|
||||
dict[str, type[BaseKernel]]: Dictionary mapping kernel IDs to kernel classes.
|
||||
|
||||
@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
|
||||
return super().__call__(model, config)
|
||||
|
||||
|
||||
@PeftPlugin("lora").register
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
||||
peft_config = LoraConfig(**config)
|
||||
model = get_peft_model(model, peft_config)
|
||||
return model
|
||||
|
||||
|
||||
@PeftPlugin("freeze").register
|
||||
@PeftPlugin("freeze").register()
|
||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
||||
raise NotImplementedError()
|
||||
|
||||
src/llamafactory/v1/plugins/model_plugins/rendering.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor


class RenderingPlugin(BasePlugin):
    pass


@RenderingPlugin("qwen").register("render_messages")
def render_qwen_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
) -> ModelInput:
    raise NotImplementedError()


@RenderingPlugin("qwen").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
    raise NotImplementedError()
@@ -12,10 +12,64 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from collections.abc import Generator
from threading import Thread

from ..config import InputArgument, SampleBackend, get_args
from ..config import InputArgument, ModelArguments, SampleArguments, SampleBackend, get_args
from ..core.base_sampler import BaseSampler
from ..core.model_loader import ModelLoader
from ..core.data_engine import DataEngine
from ..core.model_engine import ModelEngine
from ..core.utils.rendering import Renderer
from ..utils.types import HFModel, Message, Sample, TorchDataset


class SyncSampler(BaseSampler):
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
renderer: Renderer,
) -> None:
def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
asyncio.set_event_loop(loop)
loop.run_forever()

super().__init__(args, model_args, model, renderer)
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()

def generate(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
"""Generate tokens synchronously.

Args:
messages: List of messages.
tools: Tools string.

Yields:
Generated tokens.
"""
generator = super().generate(messages, tools)
while True:
try:
token = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop).result()
yield token
except StopAsyncIteration:
break

def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples synchronously.

Args:
dataset: Torch dataset.

Returns:
List of samples.
"""
return asyncio.run_coroutine_threadsafe(super().batch_infer(dataset), self._loop).result()


def run_chat(args: InputArgument = None):
@@ -23,12 +77,48 @@ def run_chat(args: InputArgument = None):
if sample_args.sample_backend != SampleBackend.HF:
model_args.init_plugin = {"name": "init_on_meta"}

model_loader = ModelLoader(model_args)
sampler = BaseSampler(sample_args, model_args, model_loader.model, model_loader.processor)
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
if data_args.dataset is not None:
sampler.batch_infer()
dataset = DataEngine(data_args)
sampler.batch_infer(dataset)
else:
sampler.generate()
if os.name != "nt":
try:
import readline # noqa: F401
except ImportError:
print("Install `readline` for a better experience.")

messages = []
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

while True:
try:
query = input("\nUser: ")
except UnicodeDecodeError:
print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
continue
except Exception:
raise

if query.strip() == "exit":
break

if query.strip() == "clear":
messages = []
print("History has been removed.")
continue

messages.append({"role": "user", "content": [{"type": "text", "value": query}]})
print("Assistant: ", end="", flush=True)

response = ""
for new_text in sampler.generate(messages):
print(new_text, end="", flush=True)
response += new_text

print()
messages.append(model_engine.renderer.parse_message(response))


if __name__ == "__main__":

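SyncSampler bridges the async BaseSampler API to this blocking CLI loop: it pins an event loop to a daemon thread and pumps the async generator with run_coroutine_threadsafe. A self-contained sketch of the same pattern, using placeholder names that are not from the repo:

```python
import asyncio
from collections.abc import AsyncGenerator, Generator
from threading import Thread


async def fake_token_stream() -> AsyncGenerator[str, None]:
    """Stand-in for an async token generator."""
    for token in ["Hello", ", ", "world", "!"]:
        await asyncio.sleep(0)
        yield token


class SyncBridge:
    def __init__(self) -> None:
        self._loop = asyncio.new_event_loop()
        # The loop runs forever on a daemon thread so sync callers can submit coroutines to it.
        self._thread = Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self) -> None:
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def stream(self) -> Generator[str, None, None]:
        agen = fake_token_stream()
        while True:
            try:
                # Drive the async generator one step at a time from the calling thread.
                yield asyncio.run_coroutine_threadsafe(agen.__anext__(), self._loop).result()
            except StopAsyncIteration:
                break


print("".join(SyncBridge().stream()))  # -> Hello, world!
```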
@@ -17,7 +17,7 @@ from ..accelerator.interface import DistributedInterface
from ..config.arg_parser import get_args
from ..core.base_trainer import BaseTrainer
from ..core.data_engine import DataEngine
from ..core.model_loader import ModelLoader
from ..core.model_engine import ModelEngine


class SFTTrainer(BaseTrainer):
@@ -28,11 +28,11 @@ def run_sft(user_args):
model_args, data_args, training_args, _ = get_args(user_args)
DistributedInterface(training_args.dist_config)
data_engine = DataEngine(data_args)
model_loader = ModelLoader(model_args)
model_engine = ModelEngine(model_args)
trainer = SFTTrainer(
args=training_args,
model=model_loader.model,
processor=model_loader.processor,
model=model_engine.model,
processor=model_engine.processor,
dataset=data_engine,
)
trainer.fit()

@@ -11,3 +11,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

IGNORE_INDEX = -100

29
src/llamafactory/v1/utils/helper.py
Normal file
29
src/llamafactory/v1/utils/helper.py
Normal file
@@ -0,0 +1,29 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import PreTrainedTokenizer

from .types import Processor


def get_tokenizer(processor: Processor) -> PreTrainedTokenizer:
"""Get tokenizer from processor.

Args:
processor: Processor.

Returns:
Tokenizer.
"""
return processor.tokenizer if hasattr(processor, "tokenizer") else processor
@@ -54,7 +54,7 @@ def _get_default_logging_level() -> "logging._Level":


def _get_library_name() -> str:
return __name__.split(".")[0]
return ".".join(__name__.split(".")[:2]) # llamafactory.v1


def _get_library_root_logger() -> "_Logger":

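get_tokenizer is a small duck-typing helper: multimodal processors expose the underlying tokenizer via a tokenizer attribute, while plain tokenizers pass through unchanged. A rough usage sketch; the multimodal checkpoint name is only an example and not something this diff depends on:

```python
from transformers import AutoProcessor, AutoTokenizer

from llamafactory.v1.utils.helper import get_tokenizer

tokenizer = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
assert get_tokenizer(tokenizer) is tokenizer  # plain tokenizer passes through

# Any processor exposing a `tokenizer` attribute is unwrapped; the model name is illustrative.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
assert get_tokenizer(processor) is processor.tokenizer
```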
@@ -13,6 +13,7 @@
# limitations under the License.


from collections import defaultdict
from collections.abc import Callable

from . import logging
@@ -27,7 +28,7 @@ class BasePlugin:
A plugin is a callable object that can be registered and called by name.
"""

_registry: dict[str, Callable] = {}
_registry: dict[str, dict[str, Callable]] = defaultdict(dict)

def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.
@@ -37,8 +38,7 @@ class BasePlugin:
"""
self.name = name

@property
def register(self):
def register(self, method_name: str = "__call__"):
"""Decorator to register a function as a plugin.

Example usage:
@@ -46,16 +46,21 @@ class BasePlugin:
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")


@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")
```
"""
if self.name is None:
raise ValueError("Plugin name is not specified.")
raise ValueError("Plugin name should be specified.")

if self.name in self._registry:
logger.warning_rank0_once(f"Plugin {self.name} is already registered.")
if method_name in self._registry[self.name]:
logger.warning_rank0_once(f"Method {method_name} of plugin {self.name} is already registered.")

def decorator(func: Callable) -> Callable:
self._registry[self.name] = func
self._registry[self.name][method_name] = func
return func

return decorator
@@ -68,10 +73,23 @@ class BasePlugin:
PrintPlugin("hello")()
```
"""
if self.name not in self._registry:
raise ValueError(f"Plugin {self.name} is not registered.")
if "__call__" not in self._registry[self.name]:
raise ValueError(f"Method __call__ of plugin {self.name} is not registered.")

return self._registry[self.name](*args, **kwargs)
return self._registry[self.name]["__call__"](*args, **kwargs)

def __getattr__(self, method_name: str):
"""Get the registered function with the given name.

Example usage:
```python
PrintPlugin("hello").again()
```
"""
if method_name not in self._registry[self.name]:
raise ValueError(f"Method {method_name} of plugin {self.name} is not registered.")

return self._registry[self.name][method_name]


if __name__ == "__main__":
@@ -82,8 +100,13 @@ if __name__ == "__main__":
class PrintPlugin(BasePlugin):
pass

@PrintPlugin("hello").register
@PrintPlugin("hello").register()
def print_hello():
print("Hello world!")

@PrintPlugin("hello").register("again")
def print_hello_again():
print("Hello world! Again.")

PrintPlugin("hello")()
PrintPlugin("hello").again()

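The core change in plugin.py is that the registry becomes a two-level map (plugin name -> method name -> callable), register() takes an optional method name, and __getattr__ resolves named methods. A stripped-down, standalone sketch of that mechanism, independent of the repo:

```python
from collections import defaultdict
from collections.abc import Callable


class MiniPlugin:
    # plugin name -> method name -> callable
    _registry: dict[str, dict[str, Callable]] = defaultdict(dict)

    def __init__(self, name: str) -> None:
        self.name = name

    def register(self, method_name: str = "__call__"):
        def decorator(func: Callable) -> Callable:
            self._registry[self.name][method_name] = func
            return func

        return decorator

    def __call__(self, *args, **kwargs):
        return self._registry[self.name]["__call__"](*args, **kwargs)

    def __getattr__(self, method_name: str):
        return self._registry[self.name][method_name]


@MiniPlugin("greet").register()
def say_hello() -> str:
    return "hello"


@MiniPlugin("greet").register("loudly")
def say_hello_loudly() -> str:
    return "HELLO"


print(MiniPlugin("greet")())         # -> hello (dispatched via __call__)
print(MiniPlugin("greet").loudly())  # -> HELLO (dispatched via __getattr__)
```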
@@ -84,27 +84,59 @@ class DistributedConfig(TypedDict, total=False):


class Content(TypedDict):
type: Literal["text", "reasoning", "tools", "tool_calls", "image_url"]
type: Literal["text", "reasoning", "tool_call", "image_url"]
"""Type of the content."""
value: str
"""Value of the content."""


class Message(TypedDict):
role: Literal["system", "user", "assistant", "tool"]
"""Role of the message."""
content: list[Content]
loss_weight: float
"""Content of the message."""
loss_weight: NotRequired[float]
"""Loss weight for this message, default to 1.0. Required in training."""


class SFTSample(TypedDict):
messages: list[Message]
"""Messages in the sample."""
extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""


class DPOSample(TypedDict):
chosen_messages: list[Message]
"""Chosen messages in the sample."""
rejected_messages: list[Message]
"""Rejected messages in the sample."""
extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""


Sample = Union[SFTSample, DPOSample]


class ToolCall(TypedDict):
name: str
"""Function name."""
arguments: str
"""Function arguments."""


class ModelInput(TypedDict, total=False):
input_ids: list[int]
"""Input ids for the model."""
attention_mask: list[int]
"""Attention mask for the model."""
labels: list[int]
"""Labels for the model."""
loss_weights: list[float]
"""Loss weight for each token, default to 1.0."""
position_ids: NotRequired[list[int] | list[list[int]]]
"""Position ids for the model (optional)."""

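With these typed dicts, a single-turn SFT sample would look roughly like the snippet below; the values are made up, and loss_weight can be omitted outside training since it is NotRequired:

```python
from llamafactory.v1.utils.types import Message, SFTSample

user_turn: Message = {
    "role": "user",
    "content": [{"type": "text", "value": "What is LLM?"}],
}
assistant_turn: Message = {
    "role": "assistant",
    "content": [{"type": "text", "value": "LLM stands for Large Language Model."}],
    "loss_weight": 1.0,  # optional; typically only assistant turns contribute to the loss
}
sample: SFTSample = {"messages": [user_turn, assistant_turn]}
```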
@@ -18,7 +18,7 @@ Contains shared fixtures, pytest configuration, and custom markers.
"""

import os
from typing import Optional
import sys

import pytest
import torch
@@ -73,7 +73,7 @@ def _handle_slow_tests(items: list[Item]):
item.add_marker(skip_slow)


def _get_visible_devices_env() -> Optional[str]:
def _get_visible_devices_env() -> str | None:
"""Return device visibility env var name."""
if CURRENT_DEVICE == "cuda":
return "CUDA_VISIBLE_DEVICES"
@@ -149,6 +149,14 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
devices_str = ",".join(str(i) for i in range(required))

monkeypatch.setenv(env_key, devices_str)

# add project root dir to path for mp run
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if project_root not in sys.path:
sys.path.insert(0, project_root)

os.environ["PYTHONPATH"] = project_root + os.pathsep + os.environ.get("PYTHONPATH", "")

else: # non-distributed test
if old_value:
visible_devices = [v for v in old_value.split(",") if v != ""]

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated
0.9.4.105
0.9.5.101

@@ -1,173 +0,0 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
|
||||
|
||||
Tests the 4 scenarios:
|
||||
a) non pack + non dynamic.
|
||||
b) non pack + dynamic.
|
||||
c) pack + non dynamic.
|
||||
d) pack + dynamic.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader as TorchDataLoader
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.v1.config.data_args import DataArguments
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.core.trainer_utils.data_collator import (
|
||||
DefaultCollator,
|
||||
)
|
||||
from llamafactory.v1.core.trainer_utils.data_loader import DataLoader
|
||||
from llamafactory.v1.plugins.data_plugins.template import QwenTemplate
|
||||
from llamafactory.v1.utils.batching_queue import TextBatchingQueue
|
||||
|
||||
|
||||
class TensorDataset(Dataset):
|
||||
"""Wrapper dataset that converts DataEngine samples to tensor format."""
|
||||
|
||||
def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
|
||||
self.data_engine = data_engine
|
||||
self.processor = processor
|
||||
self.template = template
|
||||
self.max_samples = max_samples or len(data_engine)
|
||||
self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||
|
||||
def __len__(self):
|
||||
return min(self.max_samples, len(self.data_engine))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Get sample from DataEngine
|
||||
sample = self.data_engine[idx]
|
||||
|
||||
# Extract messages from sample
|
||||
# DataEngine returns samples with format like {"messages": [...], ...}
|
||||
# For llamafactory/v1-sft-demo, the format should have "messages" field
|
||||
messages = None
|
||||
if "messages" in sample:
|
||||
messages = sample["messages"]
|
||||
elif "conversations" in sample:
|
||||
messages = sample["conversations"]
|
||||
elif "conversation" in sample:
|
||||
messages = sample["conversation"]
|
||||
else:
|
||||
# Try to find message-like fields (skip _dataset_name)
|
||||
for key, value in sample.items():
|
||||
if key.startswith("_"):
|
||||
continue
|
||||
if isinstance(value, list) and len(value) > 0:
|
||||
# Check if it looks like a message list
|
||||
if isinstance(value[0], dict) and "role" in value[0]:
|
||||
messages = value
|
||||
break
|
||||
|
||||
if messages is None:
|
||||
raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
|
||||
|
||||
# Encode messages using template
|
||||
encoded = self.template.encode_messages(self.tokenizer, messages)
|
||||
|
||||
# Convert to tensors
|
||||
return {
|
||||
"input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
||||
"attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
|
||||
"labels": torch.tensor(encoded["labels"], dtype=torch.long),
|
||||
}
|
||||
|
||||
|
||||
def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
|
||||
"""Create a real dataset using DataEngine."""
|
||||
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
|
||||
data_engine = DataEngine(data_args)
|
||||
|
||||
# Create processor and template
|
||||
processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
template = QwenTemplate()
|
||||
|
||||
# Create tensor dataset
|
||||
raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
|
||||
|
||||
# Create torch DataLoader
|
||||
torch_dataloader = TorchDataLoader(
|
||||
raw_data_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=lambda x: x,
|
||||
)
|
||||
|
||||
return torch_dataloader, processor, template
|
||||
|
||||
|
||||
class TestDataLoaderNonPackNonDynamic:
|
||||
"""Test case a) non pack + non dynamic."""
|
||||
|
||||
def test_basic_functionality(self):
|
||||
"""Test DataLoader without packing and without dynamic batching."""
|
||||
# Create real dataset
|
||||
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
||||
|
||||
# Create collator (non-packing)
|
||||
collator = DefaultCollator(processor=processor, template=template)
|
||||
|
||||
# Create DataLoader without batching_queue (non-dynamic)
|
||||
data_loader = DataLoader(
|
||||
dataloader=torch_dataloader,
|
||||
collate_fn=collator,
|
||||
num_micro_batch=1,
|
||||
batching_queue=None,
|
||||
)
|
||||
|
||||
# Iterate and check results
|
||||
batches = list(iter(data_loader))
|
||||
assert len(batches) > 0
|
||||
|
||||
# Check first batch
|
||||
one_batch = batches[0]
|
||||
micro_batches = one_batch[0]
|
||||
assert "input_ids" in micro_batches
|
||||
assert "attention_mask" in micro_batches
|
||||
assert "labels" in micro_batches
|
||||
assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
|
||||
assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
|
||||
|
||||
|
||||
class TestDataLoaderNonPackDynamic:
|
||||
"""Test case b) non pack + dynamic."""
|
||||
|
||||
def test_basic_functionality(self):
|
||||
"""Test DataLoader without packing but with dynamic batching."""
|
||||
# Create real dataset
|
||||
torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
||||
collator = DefaultCollator(processor=processor, template=template)
|
||||
|
||||
# Create batching queue for dynamic batching
|
||||
batching_queue = TextBatchingQueue(
|
||||
token_micro_bsz=120,
|
||||
buffer_size=8,
|
||||
)
|
||||
|
||||
data_loader = DataLoader(
|
||||
dataloader=torch_dataloader,
|
||||
collate_fn=collator,
|
||||
num_micro_batch=4,
|
||||
batching_queue=batching_queue,
|
||||
)
|
||||
|
||||
# Iterate and check
|
||||
batches = list(iter(data_loader))
|
||||
micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
|
||||
assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
|
||||
assert len(batches) > 0
|
||||
@@ -15,18 +15,18 @@
import torch

from llamafactory.v1.config.model_args import ModelArguments, PluginConfig
from llamafactory.v1.core.model_loader import ModelLoader
from llamafactory.v1.core.model_engine import ModelEngine


def test_tiny_qwen():
from transformers import Qwen2Config, Qwen2ForCausalLM, Qwen2TokenizerFast

model_args = ModelArguments(model="llamafactory/tiny-random-qwen2.5")
model_loader = ModelLoader(model_args)
assert isinstance(model_loader.processor, Qwen2TokenizerFast)
assert isinstance(model_loader.model.config, Qwen2Config)
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert model_loader.model.dtype == torch.bfloat16
model_engine = ModelEngine(model_args)
assert isinstance(model_engine.processor, Qwen2TokenizerFast)
assert isinstance(model_engine.model_config, Qwen2Config)
assert isinstance(model_engine.model, Qwen2ForCausalLM)
assert model_engine.model.dtype == torch.bfloat16


def test_tiny_qwen_with_kernel_plugin():
@@ -37,13 +37,14 @@ def test_tiny_qwen_with_kernel_plugin():
model_args = ModelArguments(
model="llamafactory/tiny-random-qwen2.5", kernel_config=PluginConfig(name="auto", include_kernels="auto")
)
model_loader = ModelLoader(model_args)
model_engine = ModelEngine(model_args)
# test enable apply kernel plugin
if hasattr(torch, "npu"):
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ == npu_rms_norm_forward.__code__
else:
assert model_loader.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__
assert isinstance(model_loader.model, Qwen2ForCausalLM)
assert model_engine.model.model.layers[0].input_layernorm.forward.__code__ != npu_rms_norm_forward.__code__

assert isinstance(model_engine.model, Qwen2ForCausalLM)


if __name__ == "__main__":

171
tests_v1/core/utils/test_data_loader.py
Normal file
171
tests_v1/core/utils/test_data_loader.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Integration tests for DataLoader with different combinations of packing and dynamic batching.
|
||||
|
||||
Tests the 4 scenarios:
|
||||
a) non pack + non dynamic.
|
||||
b) non pack + dynamic.
|
||||
c) pack + non dynamic.
|
||||
d) pack + dynamic.
|
||||
"""
|
||||
|
||||
# import torch
|
||||
# from torch.utils.data import DataLoader as TorchDataLoader
|
||||
# from torch.utils.data import Dataset
|
||||
# from transformers import AutoTokenizer
|
||||
|
||||
# from llamafactory.v1.config.data_args import DataArguments
|
||||
# from llamafactory.v1.core.data_engine import DataEngine
|
||||
# from llamafactory.v1.core.utils.data_collator import DefaultCollator
|
||||
# from llamafactory.v1.core.utils.data_loader import DataLoader
|
||||
# from llamafactory.v1.plugins.data_plugins.rendering import QwenTemplate
|
||||
# from llamafactory.v1.utils.batching_queue import TextBatchingQueue
|
||||
|
||||
|
||||
# class TensorDataset(Dataset):
|
||||
# """Wrapper dataset that converts DataEngine samples to tensor format."""
|
||||
|
||||
# def __init__(self, data_engine: DataEngine, processor, template, max_samples: int = None):
|
||||
# self.data_engine = data_engine
|
||||
# self.processor = processor
|
||||
# self.template = template
|
||||
# self.max_samples = max_samples or len(data_engine)
|
||||
# self.tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
|
||||
|
||||
# def __len__(self):
|
||||
# return min(self.max_samples, len(self.data_engine))
|
||||
|
||||
# def __getitem__(self, idx):
|
||||
# # Get sample from DataEngine
|
||||
# sample = self.data_engine[idx]
|
||||
|
||||
# # Extract messages from sample
|
||||
# # DataEngine returns samples with format like {"messages": [...], ...}
|
||||
# # For llamafactory/v1-sft-demo, the format should have "messages" field
|
||||
# messages = None
|
||||
# if "messages" in sample:
|
||||
# messages = sample["messages"]
|
||||
# elif "conversations" in sample:
|
||||
# messages = sample["conversations"]
|
||||
# elif "conversation" in sample:
|
||||
# messages = sample["conversation"]
|
||||
# else:
|
||||
# # Try to find message-like fields (skip _dataset_name)
|
||||
# for key, value in sample.items():
|
||||
# if key.startswith("_"):
|
||||
# continue
|
||||
# if isinstance(value, list) and len(value) > 0:
|
||||
# # Check if it looks like a message list
|
||||
# if isinstance(value[0], dict) and "role" in value[0]:
|
||||
# messages = value
|
||||
# break
|
||||
|
||||
# if messages is None:
|
||||
# raise ValueError(f"Could not find messages in sample: {list(sample.keys())}")
|
||||
|
||||
# # Encode messages using template
|
||||
# encoded = self.template.encode_messages(self.tokenizer, messages)
|
||||
|
||||
# # Convert to tensors
|
||||
# return {
|
||||
# "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
|
||||
# "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.long),
|
||||
# "labels": torch.tensor(encoded["labels"], dtype=torch.long),
|
||||
# }
|
||||
|
||||
|
||||
# def create_real_dataset(max_samples: int = 20, batch_size: int = 4):
|
||||
# """Create a real dataset using DataEngine."""
|
||||
# data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
|
||||
# data_engine = DataEngine(data_args)
|
||||
|
||||
# # Create processor and template
|
||||
# processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen2.5")
|
||||
# template = QwenTemplate()
|
||||
|
||||
# # Create tensor dataset
|
||||
# raw_data_dataset = TensorDataset(data_engine, processor, template, max_samples=max_samples)
|
||||
|
||||
# # Create torch DataLoader
|
||||
# torch_dataloader = TorchDataLoader(
|
||||
# raw_data_dataset,
|
||||
# batch_size=batch_size,
|
||||
# shuffle=False,
|
||||
# collate_fn=lambda x: x,
|
||||
# )
|
||||
|
||||
# return torch_dataloader, processor, template
|
||||
|
||||
|
||||
# class TestDataLoaderNonPackNonDynamic:
|
||||
# """Test case a) non pack + non dynamic."""
|
||||
|
||||
# def test_basic_functionality(self):
|
||||
# """Test DataLoader without packing and without dynamic batching."""
|
||||
# # Create real dataset
|
||||
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
||||
|
||||
# # Create collator (non-packing)
|
||||
# collator = DefaultCollator(processor=processor, template=template)
|
||||
|
||||
# # Create DataLoader without batching_queue (non-dynamic)
|
||||
# data_loader = DataLoader(
|
||||
# dataloader=torch_dataloader,
|
||||
# collate_fn=collator,
|
||||
# num_micro_batch=1,
|
||||
# batching_queue=None,
|
||||
# )
|
||||
|
||||
# # Iterate and check results
|
||||
# batches = list(iter(data_loader))
|
||||
# assert len(batches) > 0
|
||||
|
||||
# # Check first batch
|
||||
# one_batch = batches[0]
|
||||
# micro_batches = one_batch[0]
|
||||
# assert "input_ids" in micro_batches
|
||||
# assert "attention_mask" in micro_batches
|
||||
# assert "labels" in micro_batches
|
||||
# assert micro_batches["input_ids"].shape[0] == 1 # batch_size=1
|
||||
# assert micro_batches["input_ids"].ndim == 2 # [batch_size, seq_len]
|
||||
|
||||
|
||||
# class TestDataLoaderNonPackDynamic:
|
||||
# """Test case b) non pack + dynamic."""
|
||||
|
||||
# def test_basic_functionality(self):
|
||||
# """Test DataLoader without packing but with dynamic batching."""
|
||||
# # Create real dataset
|
||||
# torch_dataloader, processor, template = create_real_dataset(max_samples=80, batch_size=8)
|
||||
# collator = DefaultCollator(processor=processor, template=template)
|
||||
|
||||
# # Create batching queue for dynamic batching
|
||||
# batching_queue = TextBatchingQueue(
|
||||
# token_micro_bsz=120,
|
||||
# buffer_size=8,
|
||||
# )
|
||||
|
||||
# data_loader = DataLoader(
|
||||
# dataloader=torch_dataloader,
|
||||
# collate_fn=collator,
|
||||
# num_micro_batch=4,
|
||||
# batching_queue=batching_queue,
|
||||
# )
|
||||
|
||||
# # Iterate and check
|
||||
# batches = list(iter(data_loader))
|
||||
# micro_batch_tokens_first = [micro_batch["attention_mask"].sum() for micro_batch in batches[0]]
|
||||
# assert all(num_tokens <= 120 for num_tokens in micro_batch_tokens_first)
|
||||
# assert len(batches) > 0
|
||||
65
tests_v1/core/utils/test_rendering.py
Normal file
65
tests_v1/core/utils/test_rendering.py
Normal file
@@ -0,0 +1,65 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoTokenizer

from llamafactory.v1.core.utils.rendering import Renderer
from llamafactory.v1.utils.types import Processor


HF_MESSAGES = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is LLM?"},
{"role": "assistant", "content": "LLM stands for Large Language Model."},
]
V1_MESSAGES = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
]


def test_chatml_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)

hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=True)
v1_inputs = renderer.render_messages(V1_MESSAGES[:-1], is_generate=True)
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)

hf_inputs_part = tokenizer.apply_chat_template(HF_MESSAGES[:-1], add_generation_prompt=False)
hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES, add_generation_prompt=False)
v1_inputs_full = renderer.render_messages(V1_MESSAGES, is_generate=False)
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)


def test_chatml_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
generated_text = "LLM stands for Large Language Model."
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == V1_MESSAGES[-1]


if __name__ == "__main__":
test_chatml_rendering()
test_chatml_parse()
@@ -54,7 +54,7 @@ def test_sharegpt_converter():
"conversations": [
{"from": "system", "value": "System"},
{"from": "human", "value": "User"},
{"from": "function_call", "value": "Tool"},
{"from": "function_call", "value": "1"},
{"from": "observation", "value": "Observation"},
{"from": "gpt", "value": "Assistant"},
]
@@ -63,7 +63,7 @@ def test_sharegpt_converter():
"messages": [
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
{"content": [{"type": "tool_calls", "value": "Tool"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"},
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
]

@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader
from llamafactory.v1.core.model_engine import ModelEngine


def test_init_on_meta():
@@ -26,11 +25,10 @@ def test_init_on_meta():
init_config={"name": "init_on_meta"},
)
)
model_loader = ModelLoader(model_args=model_args)
assert model_loader.model.device.type == "meta"
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device.type == "meta"


@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
_, model_args, *_ = get_args(
dict(
@@ -38,11 +36,11 @@ def test_init_on_rank0():
init_config={"name": "init_on_rank0"},
)
)
model_loader = ModelLoader(model_args=model_args)
model_engine = ModelEngine(model_args=model_args)
if DistributedInterface().get_rank() == 0:
assert model_loader.model.device.type == "cpu"
assert model_engine.model.device.type == "cpu"
else:
assert model_loader.model.device.type == "meta"
assert model_engine.model.device.type == "meta"


def test_init_on_default():
@@ -52,5 +50,5 @@ def test_init_on_default():
init_config={"name": "init_on_default"},
)
)
model_loader = ModelLoader(model_args=model_args)
assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device

41
tests_v1/sampler/test_cli_sampler.py
Normal file
41
tests_v1/sampler/test_cli_sampler.py
Normal file
@@ -0,0 +1,41 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from llamafactory.v1.config import ModelArguments, SampleArguments
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.samplers.cli_sampler import SyncSampler


@pytest.mark.runs_on(["cuda", "npu"])
def test_sync_sampler():
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507")
sample_args = SampleArguments()
model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
messages = [{"role": "user", "content": [{"type": "text", "value": "Say 'This is a test.'"}]}]
response = ""
for new_text in sampler.generate(messages):
response += new_text

print(response)
assert model_engine.renderer.parse_message(response) == {
"role": "assistant",
"content": [{"type": "text", "value": "This is a test."}],
}


if __name__ == "__main__":
test_sync_sampler()