[v1] add renderer ut (#9722)

This commit is contained in:
Yaowei Zheng
2026-01-07 02:06:07 +08:00
committed by GitHub
parent ea0b4e2466
commit d22de0d4bf
13 changed files with 420 additions and 249 deletions

File diff suppressed because one or more lines are too long

View File

@@ -15,6 +15,7 @@
import os import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
AutoModelForCausalLM, AutoModelForCausalLM,
@@ -25,14 +26,11 @@ from transformers import (
AutoProcessor, AutoProcessor,
AutoTokenizer, AutoTokenizer,
) )
from packaging import version
from torch import nn
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
import warnings
from ..extras import logging from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import _get_package_version from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel from .model_utils.liger_kernel import apply_liger_kernel
@@ -206,11 +204,10 @@ def load_model(
if vhead_params is not None: if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False) model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}") logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
# Conv3D is not recommended when using torch 2.9.x # Conv3D is not recommended when using torch 2.9.x
torch_version = _get_package_version("torch") if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"): if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
raise ValueError( raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. " "Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. " "This combination is known to cause severe performance regression. "

View File

@@ -87,7 +87,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func self.compute_loss_func = dft_loss_func
elif finetuning_args.use_eaft_loss: elif finetuning_args.use_eaft_loss:
from ..trainer_utils import eaft_loss_func from ..trainer_utils import eaft_loss_func
@@ -95,7 +94,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
) )
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args) verify_fp8_status(self.accelerator, training_args)

View File

@@ -634,7 +634,9 @@ def get_batch_logps(
return logps, valid_length return logps, valid_length
def dft_loss_func(outputs, labels, num_items_in_batch=None): def dft_loss_func(
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
):
logits = outputs.get("logits") logits = outputs.get("logits")
if logits is None: if logits is None:
return outputs.get("loss", torch.tensor(0.0)) return outputs.get("loss", torch.tensor(0.0))
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
def _dft_cross_entropy( def _dft_cross_entropy(
source: torch.Tensor, source: "torch.Tensor",
target: torch.Tensor, target: "torch.Tensor",
num_items_in_batch: Optional[torch.Tensor] = None, num_items_in_batch: Optional["torch.Tensor"] = None,
ignore_index: int = -100, ignore_index: int = -100,
) -> torch.Tensor: ) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none") per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index valid_mask = target != ignore_index
if not valid_mask.any(): if not valid_mask.any():
@@ -679,7 +681,12 @@ def _dft_cross_entropy(
return loss return loss
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0): def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
) -> "torch.Tensor":
logits = outputs.get("logits") logits = outputs.get("logits")
if logits is None: if logits is None:
return outputs.get("loss", torch.tensor(0.0)) return outputs.get("loss", torch.tensor(0.0))
@@ -697,12 +704,12 @@ def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
def _eaft_cross_entropy( def _eaft_cross_entropy(
source: torch.Tensor, source: "torch.Tensor",
target: torch.Tensor, target: "torch.Tensor",
num_items_in_batch: Optional[torch.Tensor] = None, num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0, alpha: float = 1.0,
ignore_index: int = -100, ignore_index: int = -100,
) -> torch.Tensor: ) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none") per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index valid_mask = target != ignore_index
if not valid_mask.any(): if not valid_mask.any():
@@ -712,13 +719,13 @@ def _eaft_cross_entropy(
with torch.no_grad(): with torch.no_grad():
source_detached = source[valid_mask].detach() source_detached = source[valid_mask].detach()
topk_val, _ = torch.topk(source_detached, k=20, dim=-1) topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True) logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
log_probs_topk = topk_val - logsumexp_topk log_probs_topk = topk_val - logsumexp_topk
probs_topk = torch.exp(log_probs_topk) probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1) entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
entropy_term = entropy_approx / 3.0 entropy_term = entropy_approx / 3.0
adaptive_weight = torch.pow(entropy_term, alpha) adaptive_weight = torch.pow(entropy_term, alpha)
@@ -731,6 +738,7 @@ def _eaft_cross_entropy(
loss = total_loss / num_items_in_batch loss = total_loss / num_items_in_batch
else: else:
loss = weighted_losses.mean() loss = weighted_losses.mean()
return loss return loss

View File

@@ -25,7 +25,7 @@ class ModelArguments:
metadata={"help": "Path to the model or model identifier from Hugging Face."}, metadata={"help": "Path to the model or model identifier from Hugging Face."},
) )
template: str = field( template: str = field(
default="chatml", default="qwen3_nothink",
metadata={"help": "Template for the model."}, metadata={"help": "Template for the model."},
) )
trust_remote_code: bool = field( trust_remote_code: bool = field(

View File

@@ -12,38 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor from ...utils.types import Message, ModelInput, Processor
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def render_chatml_messages( def render_chatml_messages(
processor: Processor, processor: Processor,
messages: list[Message], messages: list[Message],
@@ -52,123 +26,38 @@ def render_chatml_messages(
) -> ModelInput: ) -> ModelInput:
"""Apply chatml template to messages and convert them to model input. """Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
""" """
tokenizer = get_tokenizer(processor)
input_ids, labels, loss_weights = [], [], [] input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n" for message in messages:
temp_weight = messages[0].get("loss_weight", 0.0) temp_str = "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text": if content["type"] == "text":
temp_str += content["value"] temp_str += content["value"]
else: else:
raise ValueError(f"Unsupported content type: {content['type']}") raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n" temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0) temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight) input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
for turn_idx, message in enumerate(messages): if temp_weight > 1e-6:
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0): labels.extend(temp_ids)
temp_str += "<|im_start|>" + message["role"] + "\n" else:
for content in message["content"]: labels.extend([IGNORE_INDEX] * len(temp_ids))
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate: if is_generate:
temp_str += "<|im_start|>assistant\n" temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
temp_weight = 0.0 input_ids.extend(temp_ids)
loss_weights.extend([0.0] * len(temp_ids))
labels.extend([IGNORE_INDEX] * len(temp_ids))
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput( return ModelInput(
input_ids=input_ids, input_ids=input_ids,
attention_mask=attention_mask, attention_mask=[1] * len(input_ids),
labels=labels, labels=labels,
loss_weights=loss_weights, loss_weights=loss_weights,
) )
@@ -183,36 +72,7 @@ def parse_chatml_message(generated_text: str) -> Message:
Returns: Returns:
Message: The parsed message. Message: The parsed message.
""" """
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL) return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
class Renderer: class Renderer:

View File

@@ -158,7 +158,7 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
) )
if tools: if tools:
return {"messages": messages, "extra_info": json.dumps({"tools": tools})} return {"messages": messages, "tools": json.dumps(tools)}
else: else:
return {"messages": messages} return {"messages": messages}

View File

@@ -13,24 +13,200 @@
# limitations under the License. # limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.plugin import BasePlugin from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor from ...utils.types import Message, ModelInput, Processor, ToolCall
class RenderingPlugin(BasePlugin): class RenderingPlugin(BasePlugin):
pass pass
@RenderingPlugin("qwen").register("render_messages") def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen_messages( def render_qwen_messages(
processor: Processor, processor: Processor,
messages: list[Message], messages: list[Message],
tools: str | None = None, tools: str | None = None,
is_generate: bool = False, is_generate: bool = False,
) -> ModelInput: ) -> ModelInput:
raise NotImplementedError() input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen").register("parse_message") @RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message: def parse_qwen_message(generated_text: str) -> Message:
raise NotImplementedError() pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -102,8 +102,10 @@ class Message(TypedDict):
class SFTSample(TypedDict): class SFTSample(TypedDict):
messages: list[Message] messages: list[Message]
"""Messages in the sample.""" """Messages in the sample."""
tools: NotRequired[str]
"""Tools for the sample in JSON string format."""
extra_info: NotRequired[str] extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels.""" """Extra information for the sample, e.g. kto_labels."""
_dataset_name: NotRequired[str] _dataset_name: NotRequired[str]
"""Dataset name for the sample.""" """Dataset name for the sample."""
@@ -113,8 +115,10 @@ class DPOSample(TypedDict):
"""Chosen messages in the sample.""" """Chosen messages in the sample."""
rejected_messages: list[Message] rejected_messages: list[Message]
"""Rejected messages in the sample.""" """Rejected messages in the sample."""
tools: NotRequired[str]
"""Tools for the sample in JSON string format."""
extra_info: NotRequired[str] extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels.""" """Extra information for the sample, e.g. kto_labels."""
_dataset_name: NotRequired[str] _dataset_name: NotRequired[str]
"""Dataset name for the sample.""" """Dataset name for the sample."""
@@ -125,7 +129,7 @@ Sample = Union[SFTSample, DPOSample]
class ToolCall(TypedDict): class ToolCall(TypedDict):
name: str name: str
"""Function name.""" """Function name."""
arguments: str arguments: dict[str, Any]
"""Function arguments.""" """Function arguments."""

View File

@@ -1,2 +1,2 @@
# change if test fails or cache is outdated # change if test fails or cache is outdated
0.9.5.101 0.9.5.103

View File

@@ -12,8 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import pytest
from transformers import AutoTokenizer from transformers import AutoTokenizer
from llamafactory.v1.config import DataArguments
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.utils.rendering import Renderer from llamafactory.v1.core.utils.rendering import Renderer
from llamafactory.v1.utils.types import Processor from llamafactory.v1.utils.types import Processor
@@ -23,12 +28,54 @@ HF_MESSAGES = [
{"role": "user", "content": "What is LLM?"}, {"role": "user", "content": "What is LLM?"},
{"role": "assistant", "content": "LLM stands for Large Language Model."}, {"role": "assistant", "content": "LLM stands for Large Language Model."},
] ]
V1_MESSAGES = [ V1_MESSAGES = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]}, {"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]}, {"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]}, {"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
] ]
HF_MESSAGES_WITH_TOOLS = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 6*8?"},
{
"role": "assistant",
"tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 6, "b": 8}}}],
},
{"role": "tool", "content": "48."},
{"role": "assistant", "content": "The result of 6*8 is 48."},
]
V1_MESSAGES_WITH_TOOLS = [
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
{"role": "user", "content": [{"type": "text", "value": "What is 6*8?"}]},
{
"role": "assistant",
"content": [{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})}],
"loss_weight": 0.0,
},
{"role": "tool", "content": [{"type": "text", "value": "48."}]},
{"role": "assistant", "content": [{"type": "text", "value": "The result of 6*8 is 48."}]},
]
V1_TOOLS = [
{
"type": "function",
"function": {
"name": "multiply",
"description": "A function that multiplies two numbers",
"parameters": {
"type": "object",
"properties": {
"a": {"type": "number", "description": "The first number to multiply"},
"b": {"type": "number", "description": "The second number to multiply"},
},
"required": ["a", "b"],
},
},
}
]
def test_chatml_rendering(): def test_chatml_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3") tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
@@ -60,6 +107,87 @@ def test_chatml_parse():
assert parsed_message == V1_MESSAGES[-1] assert parsed_message == V1_MESSAGES[-1]
@pytest.mark.parametrize("num_samples", [16])
def test_chatml_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
renderer = Renderer(template="chatml", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
data_engine = DataEngine(data_args)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
assert v1_inputs["input_ids"][: len(prefix)] == prefix
def test_qwen3_nothink_rendering():
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
hf_inputs = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True)
v1_inputs = renderer.render_messages(V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True)
assert v1_inputs["input_ids"] == hf_inputs
assert v1_inputs["attention_mask"] == [1] * len(hf_inputs)
assert v1_inputs["labels"] == [-100] * len(hf_inputs)
assert v1_inputs["loss_weights"] == [0.0] * len(hf_inputs)
hf_inputs_part = tokenizer.apply_chat_template(
HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
)
hf_inputs_full = tokenizer.apply_chat_template(HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False)
v1_inputs_full = renderer.render_messages(V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False)
assert v1_inputs_full["input_ids"] == hf_inputs_full
assert v1_inputs_full["attention_mask"] == [1] * len(hf_inputs_full)
assert v1_inputs_full["labels"] == [-100] * len(hf_inputs_part) + hf_inputs_full[len(hf_inputs_part) :]
assert v1_inputs_full["loss_weights"] == [0.0] * len(hf_inputs_part) + [1.0] * (
len(hf_inputs_full) - len(hf_inputs_part)
)
def test_qwen3_nothink_parse():
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
generated_text = (
"<thinking>I need to use the multiply function to calculate 6*8.</thinking>"
"Let me call the multiply function."
'<tool_call>{"name": "multiply", "arguments": {"a": 6, "b": 8}}</tool_call>'
)
parsed_message = renderer.parse_message(generated_text)
assert parsed_message == {
"role": "assistant",
"content": [
{"type": "reasoning", "value": "I need to use the multiply function to calculate 6*8."},
{"type": "text", "value": "Let me call the multiply function."},
{"type": "tool_call", "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})},
],
}
@pytest.mark.parametrize("num_samples", [8])
def test_qwen3_nothink_rendering_remote(num_samples: int):
tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
data_engine = DataEngine(data_args)
for index in range(num_samples):
v1_inputs = renderer.render_messages(data_engine[index]["messages"], tools=data_engine[index]["tools"])
prefix_text = (
"<|im_start|>system\nYou are a methodical and expert assistant. "
"Your primary goal is to solve user requests by leveraging a set of available tools. "
"You must reason for the best course of action in a structured manner before responding.\n\n"
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>\n"
'{"type": "function", "function": {"name":'
)
prefix = tokenizer.encode(prefix_text, add_special_tokens=False)
print(tokenizer.decode(v1_inputs["input_ids"][: len(prefix)]))
assert v1_inputs["input_ids"][: len(prefix)] == prefix
if __name__ == "__main__": if __name__ == "__main__":
test_chatml_rendering() test_chatml_rendering()
test_chatml_parse() test_chatml_parse()
test_chatml_rendering_remote(16)
test_qwen3_nothink_rendering()
test_qwen3_nothink_parse()
test_qwen3_nothink_rendering_remote(16)

View File

@@ -61,11 +61,11 @@ def test_sharegpt_converter():
} }
expected_data = { expected_data = {
"messages": [ "messages": [
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"}, {"role": "system", "content": [{"type": "text", "value": "System"}], "loss_weight": 0.0},
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"}, {"role": "user", "content": [{"type": "text", "value": "User"}], "loss_weight": 0.0},
{"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"}, {"role": "assistant", "content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0},
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"}, {"role": "tool", "content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0},
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"}, {"role": "assistant", "content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0},
] ]
} }
assert DataConverterPlugin("sharegpt")(example) == expected_data assert DataConverterPlugin("sharegpt")(example) == expected_data

View File

@@ -21,7 +21,7 @@ from llamafactory.v1.samplers.cli_sampler import SyncSampler
@pytest.mark.runs_on(["cuda", "npu"]) @pytest.mark.runs_on(["cuda", "npu"])
def test_sync_sampler(): def test_sync_sampler():
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507") model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507", template="qwen3_nothink")
sample_args = SampleArguments() sample_args = SampleArguments()
model_engine = ModelEngine(model_args) model_engine = ModelEngine(model_args)
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer) sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)