[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -12,10 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from threading import Thread
import torch
from transformers import TextIteratorStreamer
from ..accelerator.interface import DistributedInterface
from ..config import ModelArguments, SampleArguments, SampleBackend
from ..utils.types import HFModel, Processor, TorchDataset
from ..utils.helper import get_tokenizer
from ..utils.types import HFModel, Message, Sample, TorchDataset
from .utils.rendering import Renderer
class BaseEngine(ABC):
@@ -24,8 +34,8 @@ class BaseEngine(ABC):
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel = None,
processor: Processor = None,
model: HFModel,
renderer: Renderer,
) -> None:
"""Initialize the engine.
@@ -33,17 +43,34 @@ class BaseEngine(ABC):
args: Sample arguments.
model_args: Model arguments.
model: Model.
processor: Processor.
renderer: Renderer.
"""
...
@abstractmethod
async def generate(self, messages):
pass
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
...
@abstractmethod
async def batch_infer(self, data: TorchDataset) -> None:
pass
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
...
class HuggingFaceEngine(BaseEngine):
@@ -52,26 +79,103 @@ class HuggingFaceEngine(BaseEngine):
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
renderer: Renderer,
) -> None:
self.args = args
self.model_args = model_args
self.model = model
self.renderer = renderer
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
@torch.inference_mode()
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
streamer = TextIteratorStreamer(
tokenizer=get_tokenizer(self.renderer.processor),
skip_prompt=True,
skip_special_tokens=True, # TODO: configurable
)
device = DistributedInterface().current_device
kwargs = {
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
"max_new_tokens": self.args.max_new_tokens,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
thread.start()
def stream():
try:
return streamer.__next__()
except StopIteration:
raise StopAsyncIteration()
return stream
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
async with self.semaphore:
response = self.get_response(messages, tools)
while True:
try:
yield await asyncio.to_thread(response)
except StopAsyncIteration:
break
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
raise NotImplementedError("Batch infer is not implemented.")
class BaseSampler:
"""Base sampler.
Args:
args: Sample arguments.
model_args: Model arguments.
model: Model.
renderer: Renderer.
"""
def __init__(
self,
args: SampleArguments,
model_args: ModelArguments,
model: HFModel,
processor: Processor,
renderer: Renderer,
) -> None:
if args.sample_backend == SampleBackend.HF:
self.engine = HuggingFaceEngine(args, model_args, model, processor)
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
else:
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
async def generate(self, messages):
return await self.engine.generate(messages)
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
"""Generate tokens asynchronously.
async def batch_infer(self, data: TorchDataset) -> None:
return await self.engine.batch_infer(data)
Args:
messages: List of messages.
tools: Tools string.
Yields:
Generated tokens.
"""
async for token in self.engine.generate(messages, tools):
yield token
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
"""Batch infer samples.
Args:
dataset: Torch dataset.
Returns:
List of samples.
"""
return await self.engine.batch_infer(dataset)

View File

@@ -14,15 +14,23 @@
"""The definition of data engine.
Init Data engine:
How to use:
data_engine = DataEngine(data_args)
data_engine[i]: Get the sample via index.
Init workflow:
1. Parse dataset info from arguments.
2. Load datasets according to dataset info.
3. Build data index (and reweight samples if necessary).
Get Data Sample:
Get data sample:
1. Get sample from data index.
2. Convert sample to standard format.
3. Return sample.
Note:
1. The data engine is equivalent to the torch dataset.
2. The data engine is agnostic to the model used.
"""
import os
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
size = self.dataset_infos[dataset_name].get("size")
weight = self.dataset_infos[dataset_name].get("weight")
if size or weight: # data index plugin
from ..plugins.data_plugins.loader import DataIndexPlugin
if size or weight:
from ..plugins.data_plugins.loader import adjust_data_index
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
data_index = adjust_data_index(data_index, size, weight)
self.data_index.extend(data_index)
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
dataset_name, sample_index = self.data_index[index]
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
else: # data selector plugin
from ..plugins.data_plugins.loader import DataSelectorPlugin
from ..plugins.data_plugins.loader import select_data_sample
selected_index = DataSelectorPlugin().select(self.data_index, index)
selected_index = select_data_sample(self.data_index, index)
if isinstance(selected_index, list):
return [
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)

View File

@@ -12,16 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""The definition of model loader.
"""The definition of model engine.
How to use:
model_loader = ModelLoader(model_args, is_trainable=True)
model_loader.processor: Get the tokenizer or multi-modal processor.
model_loader.model_config: Get the model configuration.
model_loader.model: Get the HF model.
model_engine = ModelEngine(model_args, is_train=True)
model_engine.processor: Get the tokenizer or multi-modal processor.
model_engine.renderer: Get the renderer.
model_engine.model_config: Get the model configuration.
model_engine.model: Get the HF model.
Init Workflow:
Init workflow:
1. Init processor.
2. Init render.
2. Init model config.
3. Init model.
4. Init adapter.
@@ -36,17 +38,18 @@ from ..accelerator.interface import DistributedInterface
from ..config.model_args import ModelArguments, ModelClass
from ..utils import logging
from ..utils.types import HFConfig, HFModel, Processor
from .utils.rendering import Renderer
logger = logging.get_logger(__name__)
class ModelLoader:
"""Model loader.
class ModelEngine:
"""Model engine.
Args:
model_args: Model arguments.
is_trainable: Whether to train the model.
is_train: Whether to train the model.
"""
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
@@ -56,6 +59,8 @@ class ModelLoader:
"""Whether to train the model."""
self.processor = self._init_processor()
"""Tokenizer or multi-modal processor."""
self.renderer = Renderer(self.args.template, self.processor)
"""Renderer."""
self.model_config = self._init_model_config()
"""Model configuration."""
self.model = self._init_model()
@@ -107,7 +112,7 @@ class ModelLoader:
init_device = InitPlugin(self.args.init_config.name)()
else:
init_device = DistributedInterface().current_accelerator
init_device = DistributedInterface().current_device
if init_device.type == DeviceType.META:
with init_empty_weights():
@@ -144,12 +149,12 @@ class ModelLoader:
if __name__ == "__main__":
"""
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
"""
from ..config.arg_parser import get_args
_, model_args, *_ = get_args()
model_loader = ModelLoader(model_args=model_args)
print(model_loader.processor)
print(model_loader.model_config)
print(model_loader.model)
model_engine = ModelEngine(model_args=model_args)
print(model_engine.processor)
print(model_engine.model_config)
print(model_engine.model)

View File

@@ -0,0 +1,239 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def render_chatml_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
"""
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
def parse_chatml_message(generated_text: str) -> Message:
"""Parse a message in ChatML format. Supports interleaved reasoning and tool calls.
Args:
generated_text (str): The generated text in ChatML format.
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
class Renderer:
def __init__(self, template: str, processor: Processor):
self.template = template
self.processor = processor
def render_messages(
self, messages: list[Message], tools: str | None = None, is_generate: bool = False
) -> ModelInput:
if self.template == "chatml":
return render_chatml_messages(self.processor, messages, tools, is_generate)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).render_messages(self.processor, messages, tools, is_generate)
def parse_message(self, generated_text: str) -> Message:
if self.template == "chatml":
return parse_chatml_message(generated_text)
else:
from ...plugins.model_plugins.rendering import RenderingPlugin
return RenderingPlugin(self.template).parse_message(generated_text)