mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-24 06:39:08 +08:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -12,10 +12,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from ..config import ModelArguments, SampleArguments, SampleBackend
|
||||
from ..utils.types import HFModel, Processor, TorchDataset
|
||||
from ..utils.helper import get_tokenizer
|
||||
from ..utils.types import HFModel, Message, Sample, TorchDataset
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
@@ -24,8 +34,8 @@ class BaseEngine(ABC):
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel = None,
|
||||
processor: Processor = None,
|
||||
model: HFModel,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
"""Initialize the engine.
|
||||
|
||||
@@ -33,17 +43,34 @@ class BaseEngine(ABC):
|
||||
args: Sample arguments.
|
||||
model_args: Model arguments.
|
||||
model: Model.
|
||||
processor: Processor.
|
||||
renderer: Renderer.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def generate(self, messages):
|
||||
pass
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
"""Generate tokens asynchronously.
|
||||
|
||||
Args:
|
||||
messages: List of messages.
|
||||
tools: Tools string.
|
||||
|
||||
Yields:
|
||||
Generated tokens.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def batch_infer(self, data: TorchDataset) -> None:
|
||||
pass
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class HuggingFaceEngine(BaseEngine):
|
||||
@@ -52,26 +79,103 @@ class HuggingFaceEngine(BaseEngine):
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
processor: Processor,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
self.args = args
|
||||
self.model_args = model_args
|
||||
self.model = model
|
||||
self.renderer = renderer
|
||||
self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_response(self, messages: list[Message], tools: str | None = None) -> Generator[str, None, None]:
|
||||
model_inputs = self.renderer.render_messages(messages, tools, is_generate=True)
|
||||
streamer = TextIteratorStreamer(
|
||||
tokenizer=get_tokenizer(self.renderer.processor),
|
||||
skip_prompt=True,
|
||||
skip_special_tokens=True, # TODO: configurable
|
||||
)
|
||||
device = DistributedInterface().current_device
|
||||
kwargs = {
|
||||
"input_ids": torch.tensor([model_inputs["input_ids"]]).to(device),
|
||||
"attention_mask": torch.tensor([model_inputs["attention_mask"]]).to(device),
|
||||
"max_new_tokens": self.args.max_new_tokens,
|
||||
"streamer": streamer,
|
||||
}
|
||||
thread = Thread(target=self.model.generate, kwargs=kwargs, daemon=True)
|
||||
thread.start()
|
||||
|
||||
def stream():
|
||||
try:
|
||||
return streamer.__next__()
|
||||
except StopIteration:
|
||||
raise StopAsyncIteration()
|
||||
|
||||
return stream
|
||||
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
async with self.semaphore:
|
||||
response = self.get_response(messages, tools)
|
||||
while True:
|
||||
try:
|
||||
yield await asyncio.to_thread(response)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
raise NotImplementedError("Batch infer is not implemented.")
|
||||
|
||||
|
||||
class BaseSampler:
|
||||
"""Base sampler.
|
||||
|
||||
Args:
|
||||
args: Sample arguments.
|
||||
model_args: Model arguments.
|
||||
model: Model.
|
||||
renderer: Renderer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: SampleArguments,
|
||||
model_args: ModelArguments,
|
||||
model: HFModel,
|
||||
processor: Processor,
|
||||
renderer: Renderer,
|
||||
) -> None:
|
||||
if args.sample_backend == SampleBackend.HF:
|
||||
self.engine = HuggingFaceEngine(args, model_args, model, processor)
|
||||
self.engine = HuggingFaceEngine(args, model_args, model, renderer)
|
||||
else:
|
||||
raise ValueError(f"Unknown sample backend: {args.sample_backend}")
|
||||
|
||||
async def generate(self, messages):
|
||||
return await self.engine.generate(messages)
|
||||
async def generate(self, messages: list[Message], tools: str | None = None) -> AsyncGenerator[str, None]:
|
||||
"""Generate tokens asynchronously.
|
||||
|
||||
async def batch_infer(self, data: TorchDataset) -> None:
|
||||
return await self.engine.batch_infer(data)
|
||||
Args:
|
||||
messages: List of messages.
|
||||
tools: Tools string.
|
||||
|
||||
Yields:
|
||||
Generated tokens.
|
||||
"""
|
||||
async for token in self.engine.generate(messages, tools):
|
||||
yield token
|
||||
|
||||
async def batch_infer(self, dataset: TorchDataset) -> list[Sample]:
|
||||
"""Batch infer samples.
|
||||
|
||||
Args:
|
||||
dataset: Torch dataset.
|
||||
|
||||
Returns:
|
||||
List of samples.
|
||||
"""
|
||||
return await self.engine.batch_infer(dataset)
|
||||
|
||||
@@ -14,15 +14,23 @@
|
||||
|
||||
"""The definition of data engine.
|
||||
|
||||
Init Data engine:
|
||||
How to use:
|
||||
data_engine = DataEngine(data_args)
|
||||
data_engine[i]: Get the sample via index.
|
||||
|
||||
Init workflow:
|
||||
1. Parse dataset info from arguments.
|
||||
2. Load datasets according to dataset info.
|
||||
3. Build data index (and reweight samples if necessary).
|
||||
|
||||
Get Data Sample:
|
||||
Get data sample:
|
||||
1. Get sample from data index.
|
||||
2. Convert sample to standard format.
|
||||
3. Return sample.
|
||||
|
||||
Note:
|
||||
1. The data engine is equivalent to the torch dataset.
|
||||
2. The data engine is agnostic to the model used.
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -98,10 +106,10 @@ class DataEngine(Dataset):
|
||||
|
||||
size = self.dataset_infos[dataset_name].get("size")
|
||||
weight = self.dataset_infos[dataset_name].get("weight")
|
||||
if size or weight: # data index plugin
|
||||
from ..plugins.data_plugins.loader import DataIndexPlugin
|
||||
if size or weight:
|
||||
from ..plugins.data_plugins.loader import adjust_data_index
|
||||
|
||||
data_index = DataIndexPlugin().adjust_data_index(data_index, size, weight)
|
||||
data_index = adjust_data_index(data_index, size, weight)
|
||||
|
||||
self.data_index.extend(data_index)
|
||||
|
||||
@@ -150,9 +158,9 @@ class DataEngine(Dataset):
|
||||
dataset_name, sample_index = self.data_index[index]
|
||||
return self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
else: # data selector plugin
|
||||
from ..plugins.data_plugins.loader import DataSelectorPlugin
|
||||
from ..plugins.data_plugins.loader import select_data_sample
|
||||
|
||||
selected_index = DataSelectorPlugin().select(self.data_index, index)
|
||||
selected_index = select_data_sample(self.data_index, index)
|
||||
if isinstance(selected_index, list):
|
||||
return [
|
||||
self._convert_data_sample(self.datasets[dataset_name][sample_index], dataset_name)
|
||||
|
||||
@@ -12,16 +12,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""The definition of model loader.
|
||||
"""The definition of model engine.
|
||||
|
||||
How to use:
|
||||
model_loader = ModelLoader(model_args, is_trainable=True)
|
||||
model_loader.processor: Get the tokenizer or multi-modal processor.
|
||||
model_loader.model_config: Get the model configuration.
|
||||
model_loader.model: Get the HF model.
|
||||
model_engine = ModelEngine(model_args, is_train=True)
|
||||
model_engine.processor: Get the tokenizer or multi-modal processor.
|
||||
model_engine.renderer: Get the renderer.
|
||||
model_engine.model_config: Get the model configuration.
|
||||
model_engine.model: Get the HF model.
|
||||
|
||||
Init Workflow:
|
||||
Init workflow:
|
||||
1. Init processor.
|
||||
2. Init render.
|
||||
2. Init model config.
|
||||
3. Init model.
|
||||
4. Init adapter.
|
||||
@@ -36,17 +38,18 @@ from ..accelerator.interface import DistributedInterface
|
||||
from ..config.model_args import ModelArguments, ModelClass
|
||||
from ..utils import logging
|
||||
from ..utils.types import HFConfig, HFModel, Processor
|
||||
from .utils.rendering import Renderer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class ModelLoader:
|
||||
"""Model loader.
|
||||
class ModelEngine:
|
||||
"""Model engine.
|
||||
|
||||
Args:
|
||||
model_args: Model arguments.
|
||||
is_trainable: Whether to train the model.
|
||||
is_train: Whether to train the model.
|
||||
"""
|
||||
|
||||
def __init__(self, model_args: ModelArguments, is_train: bool = False) -> None:
|
||||
@@ -56,6 +59,8 @@ class ModelLoader:
|
||||
"""Whether to train the model."""
|
||||
self.processor = self._init_processor()
|
||||
"""Tokenizer or multi-modal processor."""
|
||||
self.renderer = Renderer(self.args.template, self.processor)
|
||||
"""Renderer."""
|
||||
self.model_config = self._init_model_config()
|
||||
"""Model configuration."""
|
||||
self.model = self._init_model()
|
||||
@@ -107,7 +112,7 @@ class ModelLoader:
|
||||
|
||||
init_device = InitPlugin(self.args.init_config.name)()
|
||||
else:
|
||||
init_device = DistributedInterface().current_accelerator
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
if init_device.type == DeviceType.META:
|
||||
with init_empty_weights():
|
||||
@@ -144,12 +149,12 @@ class ModelLoader:
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
python -m llamafactory.v1.core.model_loader --model llamafactory/tiny-random-qwen2.5
|
||||
python -m llamafactory.v1.core.model_engine --model llamafactory/tiny-random-qwen2.5
|
||||
"""
|
||||
from ..config.arg_parser import get_args
|
||||
|
||||
_, model_args, *_ = get_args()
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
print(model_loader.processor)
|
||||
print(model_loader.model_config)
|
||||
print(model_loader.model)
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
print(model_engine.processor)
|
||||
print(model_engine.model_config)
|
||||
print(model_engine.model)
|
||||
239
src/llamafactory/v1/core/utils/rendering.py
Normal file
239
src/llamafactory/v1/core/utils/rendering.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
|
||||
|
||||
def _update_model_input(
    processor: Processor,
    input_ids: list[int],
    labels: list[int],
    loss_weights: list[float],
    temp_str: str,
    temp_weight: float,
) -> str:
    """Tokenize ``temp_str`` and append it to the accumulated model input.

    The encoded token ids are appended to ``input_ids`` in place,
    ``temp_weight`` is recorded once per token in ``loss_weights``, and
    ``labels`` receives either the token ids (positive weight) or
    ``IGNORE_INDEX`` so those tokens are masked out of the loss.

    Returns:
        Always the empty string, so callers can reset their pending-text
        buffer with ``temp_str = _update_model_input(...)``.
    """
    if not temp_str:
        return ""

    tokenizer = get_tokenizer(processor)
    temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
    input_ids.extend(temp_ids)
    loss_weights.extend([temp_weight] * len(temp_ids))
    if temp_weight > 1e-6:  # float-safe check for "weight > 0"
        labels.extend(temp_ids)
    else:
        labels.extend([IGNORE_INDEX] * len(temp_ids))

    return ""
|
||||
|
||||
|
||||
def render_chatml_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
) -> ModelInput:
    """Apply chatml template to messages and convert them to model input.

    Args:
        processor: Tokenizer or multi-modal processor used to encode text.
        messages: Conversation messages; a leading system message (if any) is
            merged into the rendered system block.
        tools: JSON string describing the available tools (a single tool
            object or a list of them), or ``None`` when no tools are exposed.
        is_generate: When ``True``, append the assistant generation prompt so
            the model continues from ``<|im_start|>assistant``.

    Returns:
        Model input with ``input_ids``, ``attention_mask``, ``labels`` and
        per-token ``loss_weights``.

    Raises:
        ValueError: On unsupported content types, malformed ``tools`` JSON,
            or malformed tool-call payloads.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
    """
    input_ids, labels, loss_weights = [], [], []
    # temp_str accumulates rendered text until it is flushed (tokenized) by
    # _update_model_input; temp_weight is the loss weight of the pending text.
    temp_str, temp_weight = "", 0.0
    if tools:
        # Tools are rendered inside the system block, merged with the
        # user-provided system message when one is present.
        temp_str += "<|im_start|>system\n"
        if messages[0]["role"] == "system":
            for content in messages[0]["content"]:
                if content["type"] == "text":
                    temp_str += content["value"]
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "\n\n"
            temp_weight = messages[0].get("loss_weight", 0.0)

        temp_str += (
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        )
        try:
            tools = json.loads(tools)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid tools format: {str(tools)}.")

        if not isinstance(tools, list):
            tools = [tools]  # a single tool object is allowed

        for tool in tools:
            temp_str += "\n" + json.dumps(tool, ensure_ascii=False)

        temp_str += (
            "\n</tools>\n\nFor each function call, return a json object with function name "
            'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
            '<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
        )
    elif messages[0]["role"] == "system":
        temp_str += "<|im_start|>system\n"
        for content in messages[0]["content"]:
            if content["type"] == "text":
                temp_str += content["value"]
            else:
                raise ValueError(f"Unsupported content type: {content['type']}")

        temp_str += "<|im_end|>\n"
        temp_weight = messages[0].get("loss_weight", 0.0)

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    for turn_idx, message in enumerate(messages):
        # NOTE: the leading system message (turn 0) was already rendered above,
        # so it is deliberately excluded here.
        if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
            temp_str += "<|im_start|>" + message["role"] + "\n"
            for content in message["content"]:
                if content["type"] == "text":
                    temp_str += content["value"]
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 0.0)
        elif message["role"] == "assistant":
            temp_str += "<|im_start|>" + message["role"] + "\n"
            for val_idx, content in enumerate(message["content"]):
                if content["type"] == "text":
                    temp_str += content["value"]
                elif content["type"] == "reasoning":
                    temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n"  # avoid using special tokens
                elif content["type"] == "tool_call":
                    # Separate a tool call from a preceding text/tool_call chunk.
                    if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
                        temp_str += "\n"

                    try:
                        tool_call = json.loads(content["value"])
                    except json.JSONDecodeError:
                        raise ValueError(f"Invalid tool call format: {content['value']}.")
                    temp_str += (
                        '<tool_call>\n{"name": "'
                        + tool_call["name"]
                        + '", "arguments": '
                        + json.dumps(tool_call["arguments"], ensure_ascii=False)
                        + "}\n</tool_call>"
                    )

                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 1.0)  # assistant turns are trained on by default
        elif message["role"] == "tool":
            # Consecutive tool messages are folded into a single "user" turn:
            # open the turn only when the previous message was not a tool message.
            if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
                temp_str += "<|im_start|>user"

            temp_str += "\n<tool_response>\n"
            for content in message["content"]:
                if content["type"] == "text":
                    temp_str += content["value"]
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "\n</tool_response>"
            # Close the turn only when the next message is not a tool message.
            if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
                temp_str += "<|im_end|>\n"

            temp_weight = message.get("loss_weight", 0.0)

        # Flush this turn with its own loss weight.
        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    if is_generate:
        # Open the assistant turn so generation continues from here.
        temp_str += "<|im_start|>assistant\n"
        temp_weight = 0.0

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    attention_mask = [1] * len(input_ids)
    return ModelInput(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        loss_weights=loss_weights,
    )
|
||||
|
||||
|
||||
def parse_chatml_message(generated_text: str) -> Message:
    """Parse a message in ChatML format. Supports interleaved reasoning and tool calls.

    Args:
        generated_text (str): The generated text in ChatML format.

    Returns:
        Message: The parsed message.
    """
    tag_pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    parts: list[dict] = []
    cursor = 0

    def append_plain_text(segment: str) -> None:
        # Plain text between (or after) tagged blocks; skip whitespace-only runs.
        stripped = segment.strip()
        if stripped:
            parts.append({"type": "text", "value": stripped})

    for found in tag_pattern.finditer(generated_text):
        begin, finish = found.span()
        if begin > cursor:
            append_plain_text(generated_text[cursor:begin])

        kind, body = found.group(1), found.group(2).strip()
        if kind == "thinking":
            parts.append({"type": "reasoning", "value": body})
        elif kind == "tool_call":
            # Validate that the payload is well-formed JSON before accepting it.
            try:
                json.loads(body)
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {body}.")

            parts.append({"type": "tool_call", "value": body})

        cursor = finish

    if cursor < len(generated_text):
        append_plain_text(generated_text[cursor:])

    return Message(role="assistant", content=parts)
|
||||
|
||||
|
||||
class Renderer:
    """Render messages into model inputs and parse generated text back into messages.

    The built-in ``chatml`` template is handled locally; any other template
    name is delegated to the project's rendering plugin.
    """

    def __init__(self, template: str, processor: Processor):
        self.template = template
        self.processor = processor

    def _plugin(self):
        # Imported lazily so the plugin package is only loaded when a
        # non-chatml template is actually used.
        from ...plugins.model_plugins.rendering import RenderingPlugin

        return RenderingPlugin(self.template)

    def render_messages(
        self, messages: list[Message], tools: str | None = None, is_generate: bool = False
    ) -> ModelInput:
        """Apply the configured template to ``messages`` and return the model input."""
        if self.template != "chatml":
            return self._plugin().render_messages(self.processor, messages, tools, is_generate)

        return render_chatml_messages(self.processor, messages, tools, is_generate)

    def parse_message(self, generated_text: str) -> Message:
        """Parse ``generated_text`` back into an assistant message."""
        if self.template != "chatml":
            return self._plugin().parse_message(generated_text)

        return parse_chatml_message(generated_text)
|
||||
Reference in New Issue
Block a user