mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-12 17:10:36 +08:00
[v1] add renderer ut (#9722)
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -15,6 +15,7 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
@@ -25,14 +26,11 @@ from transformers import (
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
)
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
import warnings
|
||||
|
||||
from ..extras import logging
|
||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||
from ..extras.packages import _get_package_version
|
||||
from ..extras.packages import is_torch_version_greater_than
|
||||
from .adapter import init_adapter
|
||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||
from .model_utils.liger_kernel import apply_liger_kernel
|
||||
@@ -208,9 +206,8 @@ def load_model(
|
||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||
|
||||
# Conv3D is not recommended when using torch 2.9.x
|
||||
torch_version = _get_package_version("torch")
|
||||
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
|
||||
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
|
||||
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
|
||||
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
|
||||
raise ValueError(
|
||||
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||
"This combination is known to cause severe performance regression. "
|
||||
|
||||
@@ -87,7 +87,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
|
||||
self.compute_loss_func = dft_loss_func
|
||||
|
||||
|
||||
elif finetuning_args.use_eaft_loss:
|
||||
from ..trainer_utils import eaft_loss_func
|
||||
|
||||
@@ -95,7 +94,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||
)
|
||||
|
||||
|
||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||
verify_fp8_status(self.accelerator, training_args)
|
||||
|
||||
|
||||
@@ -634,7 +634,9 @@ def get_batch_logps(
|
||||
return logps, valid_length
|
||||
|
||||
|
||||
def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
def dft_loss_func(
|
||||
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
|
||||
):
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
||||
|
||||
|
||||
def _dft_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
@@ -679,7 +681,12 @@ def _dft_cross_entropy(
|
||||
return loss
|
||||
|
||||
|
||||
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
|
||||
def eaft_loss_func(
|
||||
outputs: "torch.Tensor",
|
||||
labels: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
) -> "torch.Tensor":
|
||||
logits = outputs.get("logits")
|
||||
if logits is None:
|
||||
return outputs.get("loss", torch.tensor(0.0))
|
||||
@@ -697,12 +704,12 @@ def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
|
||||
|
||||
|
||||
def _eaft_cross_entropy(
|
||||
source: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
||||
source: "torch.Tensor",
|
||||
target: "torch.Tensor",
|
||||
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||
alpha: float = 1.0,
|
||||
ignore_index: int = -100,
|
||||
) -> torch.Tensor:
|
||||
) -> "torch.Tensor":
|
||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||
valid_mask = target != ignore_index
|
||||
if not valid_mask.any():
|
||||
@@ -731,6 +738,7 @@ def _eaft_cross_entropy(
|
||||
loss = total_loss / num_items_in_batch
|
||||
else:
|
||||
loss = weighted_losses.mean()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class ModelArguments:
|
||||
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
||||
)
|
||||
template: str = field(
|
||||
default="chatml",
|
||||
default="qwen3_nothink",
|
||||
metadata={"help": "Template for the model."},
|
||||
)
|
||||
trust_remote_code: bool = field(
|
||||
|
||||
@@ -12,38 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
|
||||
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def render_chatml_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
@@ -52,123 +26,38 @@ def render_chatml_messages(
|
||||
) -> ModelInput:
|
||||
"""Apply chatml template to messages and convert them to model input.
|
||||
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
|
||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
|
||||
"""
|
||||
tokenizer = get_tokenizer(processor)
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n"
|
||||
for content in messages[0]["content"]:
|
||||
for message in messages:
|
||||
temp_str = "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([0.0] * len(temp_ids))
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
attention_mask=[1] * len(input_ids),
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
@@ -183,36 +72,7 @@ def parse_chatml_message(generated_text: str) -> Message:
|
||||
Returns:
|
||||
Message: The parsed message.
|
||||
"""
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
|
||||
|
||||
|
||||
class Renderer:
|
||||
|
||||
@@ -158,7 +158,7 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
||||
)
|
||||
|
||||
if tools:
|
||||
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
|
||||
return {"messages": messages, "tools": json.dumps(tools)}
|
||||
else:
|
||||
return {"messages": messages}
|
||||
|
||||
|
||||
@@ -13,24 +13,200 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import json
|
||||
import re
|
||||
|
||||
from ...utils.constants import IGNORE_INDEX
|
||||
from ...utils.helper import get_tokenizer
|
||||
from ...utils.plugin import BasePlugin
|
||||
from ...utils.types import Message, ModelInput, Processor
|
||||
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
||||
|
||||
|
||||
class RenderingPlugin(BasePlugin):
|
||||
pass
|
||||
|
||||
|
||||
@RenderingPlugin("qwen").register("render_messages")
|
||||
def _update_model_input(
|
||||
processor: Processor,
|
||||
input_ids: list[int],
|
||||
labels: list[int],
|
||||
loss_weights: list[int],
|
||||
temp_str: str,
|
||||
temp_weight: float,
|
||||
) -> str:
|
||||
"""Update model input with temporary string."""
|
||||
if not temp_str:
|
||||
return ""
|
||||
|
||||
tokenizer = get_tokenizer(processor)
|
||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||
input_ids.extend(temp_ids)
|
||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||
if temp_weight > 1e-6:
|
||||
labels.extend(temp_ids)
|
||||
else:
|
||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||
def render_qwen_messages(
|
||||
processor: Processor,
|
||||
messages: list[Message],
|
||||
tools: str | None = None,
|
||||
is_generate: bool = False,
|
||||
) -> ModelInput:
|
||||
raise NotImplementedError()
|
||||
input_ids, labels, loss_weights = [], [], []
|
||||
temp_str, temp_weight = "", 0.0
|
||||
if tools:
|
||||
temp_str += "<|im_start|>system\n"
|
||||
if messages[0]["role"] == "system":
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str += (
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||
|
||||
if not isinstance(tools, list):
|
||||
tools = [tools]
|
||||
|
||||
for tool in tools:
|
||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
temp_str += (
|
||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||
)
|
||||
elif messages[0]["role"] == "system":
|
||||
temp_str += "<|im_start|>system\n"
|
||||
for content in messages[0]["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
for turn_idx, message in enumerate(messages):
|
||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
elif message["role"] == "assistant":
|
||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||
for val_idx, content in enumerate(message["content"]):
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
elif content["type"] == "reasoning":
|
||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||
elif content["type"] == "tool_call":
|
||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||
temp_str += "\n"
|
||||
|
||||
try:
|
||||
tool_call: ToolCall = json.loads(content["value"])
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||
|
||||
temp_str += (
|
||||
'<tool_call>\n{"name": "'
|
||||
+ tool_call["name"]
|
||||
+ '", "arguments": '
|
||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||
+ "}\n</tool_call>"
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "<|im_end|>\n"
|
||||
temp_weight = message.get("loss_weight", 1.0)
|
||||
elif message["role"] == "tool":
|
||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||
temp_str += "<|im_start|>user"
|
||||
|
||||
temp_str += "\n<tool_response>\n"
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
temp_str += content["value"]
|
||||
else:
|
||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||
|
||||
temp_str += "\n</tool_response>"
|
||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||
temp_str += "<|im_end|>\n"
|
||||
|
||||
temp_weight = message.get("loss_weight", 0.0)
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
if is_generate:
|
||||
temp_str += "<|im_start|>assistant\n"
|
||||
temp_weight = 0.0
|
||||
|
||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||
|
||||
attention_mask = [1] * len(input_ids)
|
||||
return ModelInput(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
loss_weights=loss_weights,
|
||||
)
|
||||
|
||||
|
||||
@RenderingPlugin("qwen").register("parse_message")
|
||||
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||
def parse_qwen_message(generated_text: str) -> Message:
|
||||
raise NotImplementedError()
|
||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||
content = []
|
||||
last_end = 0
|
||||
for match in pattern.finditer(generated_text):
|
||||
start, end = match.span()
|
||||
if start > last_end:
|
||||
text = generated_text[last_end:start].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
tag_type = match.group(1)
|
||||
tag_value = match.group(2).strip()
|
||||
if tag_type == "thinking":
|
||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||
elif tag_type == "tool_call":
|
||||
try:
|
||||
json.loads(tag_value.strip())
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||
|
||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||
|
||||
last_end = end
|
||||
|
||||
if last_end < len(generated_text):
|
||||
text = generated_text[last_end:].strip()
|
||||
if text:
|
||||
content.append({"type": "text", "value": text})
|
||||
|
||||
return Message(role="assistant", content=content)
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -102,8 +102,10 @@ class Message(TypedDict):
|
||||
class SFTSample(TypedDict):
|
||||
messages: list[Message]
|
||||
"""Messages in the sample."""
|
||||
tools: NotRequired[str]
|
||||
"""Tools for the sample in JSON string format."""
|
||||
extra_info: NotRequired[str]
|
||||
"""Extra information for the sample, including tools, kto_labels."""
|
||||
"""Extra information for the sample, e.g. kto_labels."""
|
||||
_dataset_name: NotRequired[str]
|
||||
"""Dataset name for the sample."""
|
||||
|
||||
@@ -113,8 +115,10 @@ class DPOSample(TypedDict):
|
||||
"""Chosen messages in the sample."""
|
||||
rejected_messages: list[Message]
|
||||
"""Rejected messages in the sample."""
|
||||
tools: NotRequired[str]
|
||||
"""Tools for the sample in JSON string format."""
|
||||
extra_info: NotRequired[str]
|
||||
"""Extra information for the sample, including tools, kto_labels."""
|
||||
"""Extra information for the sample, e.g. kto_labels."""
|
||||
_dataset_name: NotRequired[str]
|
||||
"""Dataset name for the sample."""
|
||||
|
||||
@@ -125,7 +129,7 @@ Sample = Union[SFTSample, DPOSample]
|
||||
class ToolCall(TypedDict):
|
||||
name: str
|
||||
"""Function name."""
|
||||
arguments: str
|
||||
arguments: dict[str, Any]
|
||||
"""Function arguments."""
|
||||
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# change if test fails or cache is outdated
|
||||
0.9.5.101
|
||||
0.9.5.103
|
||||
|
||||
@@ -12,8 +12,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from llamafactory.v1.config import DataArguments
|
||||
from llamafactory.v1.core.data_engine import DataEngine
|
||||
from llamafactory.v1.core.utils.rendering import Renderer
|
||||
from llamafactory.v1.utils.types import Processor
|
||||
|
||||
@@ -23,12 +28,54 @@ HF_MESSAGES = [
|
||||
{"role": "user", "content": "What is LLM?"},
|
||||
{"role": "assistant", "content": "LLM stands for Large Language Model."},
|
||||
]
|
||||
|
||||
V1_MESSAGES = [
|
||||
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
|
||||
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
|
||||
]
|
||||
|
||||
HF_MESSAGES_WITH_TOOLS = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is 6*8?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 6, "b": 8}}}],
|
||||
},
|
||||
{"role": "tool", "content": "48."},
|
||||
{"role": "assistant", "content": "The result of 6*8 is 48."},
|
||||
]
|
||||
|
||||
V1_MESSAGES_WITH_TOOLS = [
|
||||
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
|
||||
{"role": "user", "content": [{"type": "text", "value": "What is 6*8?"}]},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})}],
|
||||
"loss_weight": 0.0,
|
||||
},
|
||||
{"role": "tool", "content": [{"type": "text", "value": "48."}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "value": "The result of 6*8 is 48."}]},
|
||||
]
|
||||
|
||||
V1_TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "multiply",
|
||||
"description": "A function that multiplies two numbers",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "number", "description": "The first number to multiply"},
|
||||
"b": {"type": "number", "description": "The second number to multiply"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_chatml_rendering():
|
||||
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
|
||||
@@ -60,6 +107,87 @@ def test_chatml_parse():
|
||||
assert parsed_message == V1_MESSAGES[-1]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [16])
|
||||
def test_chatml_rendering_remote(num_samples: int):
|
||||
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
|
||||
renderer = Renderer(template="chatml", processor=tokenizer)
|
||||
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
|
||||
data_engine = DataEngine(data_args)
|
||||
for index in range(num_samples):
|
||||
v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
|
||||
prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
|
||||
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
|
||||
assert v1_inputs["input_ids"][: len(prefix)] == prefix
|
||||
|
||||
|
||||
def test_qwen3_nothink_rendering():
|
||||
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
|
||||
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
|
||||
|
||||
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
|
||||
v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
|
||||
assert v1_inputs["input_ids"] == hf_inputs
|
||||
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
|
||||
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
|
||||
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
|
||||
|
||||
hf_inputs_part = tokenizer.apply_chat_template(
|
||||
HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
|
||||
)
|
||||
hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
|
||||
v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
|
||||
assert v1_inputs_full["input_ids"] == hf_inputs_full
|
||||
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
|
||||
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
|
||||
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
|
||||
len(hf_inputs_full) - len(hf_inputs_part)
|
||||
)
|
||||
|
||||
|
||||
def test_qwen3_nothink_parse():
|
||||
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
|
||||
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
|
||||
generated_text = (
|
||||
"<thinking>I need to use the multiply function to calculate 6*8.</thinking>"
|
||||
"Let me call the multiply function."
|
||||
'<tool_call>{"name": "multiply", "arguments": {"a": 6, "b": 8}}</tool_call>'
|
||||
)
|
||||
parsed_message = renderer.parse_message(generated_text)
|
||||
assert parsed_message == {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "reasoning", "value": "I need to use the multiply function to calculate 6*8."},
|
||||
{"type": "text", "value": "Let me call the multiply function."},
|
||||
{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_samples", [8])
|
||||
def test_qwen3_nothink_rendering_remote(num_samples: int):
|
||||
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
|
||||
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
|
||||
data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
|
||||
data_engine = DataEngine(data_args)
|
||||
for index in range(num_samples):
|
||||
v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
|
||||
prefix_text = (
|
||||
"<|im_start|>system\nYou are a methodical and expert assistant. "
|
||||
"Your primary goal is to solve user requests by leveraging a set of available tools. "
|
||||
"You must reason for the best course of action in a structured manner before responding.\n\n"
|
||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>\n"
|
||||
'{"type": "function", "function": {"name":'
|
||||
)
|
||||
prefix = tokenizer.encode(prefix_text, add_special_tokens=False)
|
||||
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
|
||||
assert v1_inputs["input_ids"][: len(prefix)] == prefix
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_chatml_rendering()
|
||||
test_chatml_parse()
|
||||
test_chatml_rendering_remote(16)
|
||||
test_qwen3_nothink_rendering()
|
||||
test_qwen3_nothink_parse()
|
||||
test_qwen3_nothink_rendering_remote(16)
|
||||
|
||||
@@ -61,11 +61,11 @@ def test_sharegpt_converter():
|
||||
}
|
||||
expected_data = {
|
||||
"messages": [
|
||||
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
|
||||
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
|
||||
{"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"},
|
||||
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
|
||||
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
|
||||
{"role": "system", "content": [{"type": "text", "value": "System"}], "loss_weight": 0.0},
|
||||
{"role": "user", "content": [{"type": "text", "value": "User"}], "loss_weight": 0.0},
|
||||
{"role": "assistant", "content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0},
|
||||
{"role": "tool", "content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0},
|
||||
{"role": "assistant", "content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0},
|
||||
]
|
||||
}
|
||||
assert DataConverterPlugin("sharegpt")(example) == expected_data
|
||||
|
||||
@@ -21,7 +21,7 @@ from llamafactory.v1.samplers.cli_sampler import SyncSampler
|
||||
|
||||
@pytest.mark.runs_on(["cuda", "npu"])
|
||||
def test_sync_sampler():
|
||||
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507")
|
||||
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507", template="qwen3_nothink")
|
||||
sample_args = SampleArguments()
|
||||
model_engine = ModelEngine(model_args)
|
||||
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
|
||||
|
||||
Reference in New Issue
Block a user