[v1] add renderer ut (#9722)

This commit is contained in:
Yaowei Zheng
2026-01-07 02:06:07 +08:00
committed by GitHub
parent ea0b4e2466
commit d22de0d4bf
13 changed files with 420 additions and 249 deletions

View File

@@ -15,6 +15,7 @@
import os
from typing import TYPE_CHECKING, Any, Optional, TypedDict
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
@@ -25,14 +26,11 @@ from transformers import (
AutoProcessor,
AutoTokenizer,
)
from packaging import version
from torch import nn
from trl import AutoModelForCausalLMWithValueHead
import warnings
from ..extras import logging
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
from ..extras.packages import _get_package_version
from ..extras.packages import is_torch_version_greater_than
from .adapter import init_adapter
from .model_utils.ktransformers import load_kt_pretrained_model
from .model_utils.liger_kernel import apply_liger_kernel
@@ -206,11 +204,10 @@ def load_model(
if vhead_params is not None:
model.load_state_dict(vhead_params, strict=False)
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
# Conv3D is not recommended when using torch 2.9.x
torch_version = _get_package_version("torch")
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
raise ValueError(
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
"This combination is known to cause severe performance regression. "

View File

@@ -87,7 +87,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.compute_loss_func = dft_loss_func
elif finetuning_args.use_eaft_loss:
from ..trainer_utils import eaft_loss_func
@@ -95,7 +94,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
)
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
verify_fp8_status(self.accelerator, training_args)

View File

@@ -634,7 +634,9 @@ def get_batch_logps(
return logps, valid_length
def dft_loss_func(outputs, labels, num_items_in_batch=None):
def dft_loss_func(
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
):
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
def _dft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
@@ -679,7 +681,12 @@ def _dft_cross_entropy(
return loss
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
def eaft_loss_func(
outputs: "torch.Tensor",
labels: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
) -> "torch.Tensor":
logits = outputs.get("logits")
if logits is None:
return outputs.get("loss", torch.tensor(0.0))
@@ -697,12 +704,12 @@ def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
def _eaft_cross_entropy(
source: torch.Tensor,
target: torch.Tensor,
num_items_in_batch: Optional[torch.Tensor] = None,
source: "torch.Tensor",
target: "torch.Tensor",
num_items_in_batch: Optional["torch.Tensor"] = None,
alpha: float = 1.0,
ignore_index: int = -100,
) -> torch.Tensor:
) -> "torch.Tensor":
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
valid_mask = target != ignore_index
if not valid_mask.any():
@@ -712,13 +719,13 @@ def _eaft_cross_entropy(
with torch.no_grad():
source_detached = source[valid_mask].detach()
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
log_probs_topk = topk_val - logsumexp_topk
probs_topk = torch.exp(log_probs_topk)
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
entropy_term = entropy_approx / 3.0
adaptive_weight = torch.pow(entropy_term, alpha)
@@ -731,6 +738,7 @@ def _eaft_cross_entropy(
loss = total_loss / num_items_in_batch
else:
loss = weighted_losses.mean()
return loss

View File

@@ -25,7 +25,7 @@ class ModelArguments:
metadata={"help": "Path to the model or model identifier from Hugging Face."},
)
template: str = field(
default="chatml",
default="qwen3_nothink",
metadata={"help": "Template for the model."},
)
trust_remote_code: bool = field(

View File

@@ -12,38 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.types import Message, ModelInput, Processor
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
def render_chatml_messages(
processor: Processor,
messages: list[Message],
@@ -52,123 +26,38 @@ def render_chatml_messages(
) -> ModelInput:
"""Apply chatml template to messages and convert them to model input.
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
"""
tokenizer = get_tokenizer(processor)
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
for message in messages:
temp_str = "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([0.0] * len(temp_ids))
labels.extend([IGNORE_INDEX] * len(temp_ids))
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
attention_mask=[1] * len(input_ids),
labels=labels,
loss_weights=loss_weights,
)
@@ -183,36 +72,7 @@ def parse_chatml_message(generated_text: str) -> Message:
Returns:
Message: The parsed message.
"""
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
class Renderer:

View File

@@ -158,7 +158,7 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
)
if tools:
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
return {"messages": messages, "tools": json.dumps(tools)}
else:
return {"messages": messages}

View File

@@ -13,24 +13,200 @@
# limitations under the License.
import json
import re
from ...utils.constants import IGNORE_INDEX
from ...utils.helper import get_tokenizer
from ...utils.plugin import BasePlugin
from ...utils.types import Message, ModelInput, Processor
from ...utils.types import Message, ModelInput, Processor, ToolCall
class RenderingPlugin(BasePlugin):
pass
@RenderingPlugin("qwen").register("render_messages")
def _update_model_input(
processor: Processor,
input_ids: list[int],
labels: list[int],
loss_weights: list[int],
temp_str: str,
temp_weight: float,
) -> str:
"""Update model input with temporary string."""
if not temp_str:
return ""
tokenizer = get_tokenizer(processor)
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
input_ids.extend(temp_ids)
loss_weights.extend([temp_weight] * len(temp_ids))
if temp_weight > 1e-6:
labels.extend(temp_ids)
else:
labels.extend([IGNORE_INDEX] * len(temp_ids))
return ""
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen_messages(
processor: Processor,
messages: list[Message],
tools: str | None = None,
is_generate: bool = False,
) -> ModelInput:
raise NotImplementedError()
input_ids, labels, loss_weights = [], [], []
temp_str, temp_weight = "", 0.0
if tools:
temp_str += "<|im_start|>system\n"
if messages[0]["role"] == "system":
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str += (
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
)
try:
tools = json.loads(tools)
except json.JSONDecodeError:
raise ValueError(f"Invalid tools format: {str(tools)}.")
if not isinstance(tools, list):
tools = [tools]
for tool in tools:
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
temp_str += (
"\n</tools>\n\nFor each function call, return a json object with function name "
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
)
elif messages[0]["role"] == "system":
temp_str += "<|im_start|>system\n"
for content in messages[0]["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = messages[0].get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
for turn_idx, message in enumerate(messages):
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
temp_str += "<|im_start|>" + message["role"] + "\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
elif message["role"] == "assistant":
temp_str += "<|im_start|>" + message["role"] + "\n"
for val_idx, content in enumerate(message["content"]):
if content["type"] == "text":
temp_str += content["value"]
elif content["type"] == "reasoning":
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
elif content["type"] == "tool_call":
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
temp_str += "\n"
try:
tool_call: ToolCall = json.loads(content["value"])
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {content['value']}.")
temp_str += (
'<tool_call>\n{"name": "'
+ tool_call["name"]
+ '", "arguments": '
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
+ "}\n</tool_call>"
)
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 1.0)
elif message["role"] == "tool":
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
temp_str += "<|im_start|>user"
temp_str += "\n<tool_response>\n"
for content in message["content"]:
if content["type"] == "text":
temp_str += content["value"]
else:
raise ValueError(f"Unsupported content type: {content['type']}")
temp_str += "\n</tool_response>"
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
temp_str += "<|im_end|>\n"
temp_weight = message.get("loss_weight", 0.0)
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
if is_generate:
temp_str += "<|im_start|>assistant\n"
temp_weight = 0.0
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
attention_mask = [1] * len(input_ids)
return ModelInput(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
loss_weights=loss_weights,
)
@RenderingPlugin("qwen").register("parse_message")
@RenderingPlugin("qwen3_nothink").register("parse_message")
def parse_qwen_message(generated_text: str) -> Message:
raise NotImplementedError()
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
content = []
last_end = 0
for match in pattern.finditer(generated_text):
start, end = match.span()
if start > last_end:
text = generated_text[last_end:start].strip()
if text:
content.append({"type": "text", "value": text})
tag_type = match.group(1)
tag_value = match.group(2).strip()
if tag_type == "thinking":
content.append({"type": "reasoning", "value": tag_value.strip()})
elif tag_type == "tool_call":
try:
json.loads(tag_value.strip())
except json.JSONDecodeError:
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
content.append({"type": "tool_call", "value": tag_value.strip()})
last_end = end
if last_end < len(generated_text):
text = generated_text[last_end:].strip()
if text:
content.append({"type": "text", "value": text})
return Message(role="assistant", content=content)

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
if TYPE_CHECKING:
@@ -102,8 +102,10 @@ class Message(TypedDict):
class SFTSample(TypedDict):
messages: list[Message]
"""Messages in the sample."""
tools: NotRequired[str]
"""Tools for the sample in JSON string format."""
extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
"""Extra information for the sample, e.g. kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""
@@ -113,8 +115,10 @@ class DPOSample(TypedDict):
"""Chosen messages in the sample."""
rejected_messages: list[Message]
"""Rejected messages in the sample."""
tools: NotRequired[str]
"""Tools for the sample in JSON string format."""
extra_info: NotRequired[str]
"""Extra information for the sample, including tools, kto_labels."""
"""Extra information for the sample, e.g. kto_labels."""
_dataset_name: NotRequired[str]
"""Dataset name for the sample."""
@@ -125,7 +129,7 @@ Sample = Union[SFTSample, DPOSample]
class ToolCall(TypedDict):
name: str
"""Function name."""
arguments: str
arguments: dict[str, Any]
"""Function arguments."""