mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-12 17:10:36 +08:00
[v1] add renderer ut (#9722)
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -15,6 +15,7 @@
|
|||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
from typing import TYPE_CHECKING, Any, Optional, TypedDict
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoConfig,
|
AutoConfig,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
@@ -25,14 +26,11 @@ from transformers import (
|
|||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
)
|
)
|
||||||
from packaging import version
|
|
||||||
from torch import nn
|
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
import warnings
|
|
||||||
|
|
||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
|
||||||
from ..extras.packages import _get_package_version
|
from ..extras.packages import is_torch_version_greater_than
|
||||||
from .adapter import init_adapter
|
from .adapter import init_adapter
|
||||||
from .model_utils.ktransformers import load_kt_pretrained_model
|
from .model_utils.ktransformers import load_kt_pretrained_model
|
||||||
from .model_utils.liger_kernel import apply_liger_kernel
|
from .model_utils.liger_kernel import apply_liger_kernel
|
||||||
@@ -206,11 +204,10 @@ def load_model(
|
|||||||
if vhead_params is not None:
|
if vhead_params is not None:
|
||||||
model.load_state_dict(vhead_params, strict=False)
|
model.load_state_dict(vhead_params, strict=False)
|
||||||
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")
|
||||||
|
|
||||||
# Conv3D is not recommended when using torch 2.9.x
|
# Conv3D is not recommended when using torch 2.9.x
|
||||||
torch_version = _get_package_version("torch")
|
if is_torch_version_greater_than("2.9.0") and not is_torch_version_greater_than("2.10.0"):
|
||||||
if version.parse("2.9.0") <= torch_version < version.parse("2.10.0"):
|
if any(isinstance(m, torch.nn.Conv3d) for m in model.modules()):
|
||||||
if any(isinstance(m, nn.Conv3d) for m in model.modules()):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
"Unsupported torch version detected: torch 2.9.x with Conv3D. "
|
||||||
"This combination is known to cause severe performance regression. "
|
"This combination is known to cause severe performance regression. "
|
||||||
|
|||||||
@@ -87,7 +87,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
|
|
||||||
self.compute_loss_func = dft_loss_func
|
self.compute_loss_func = dft_loss_func
|
||||||
|
|
||||||
|
|
||||||
elif finetuning_args.use_eaft_loss:
|
elif finetuning_args.use_eaft_loss:
|
||||||
from ..trainer_utils import eaft_loss_func
|
from ..trainer_utils import eaft_loss_func
|
||||||
|
|
||||||
@@ -95,7 +94,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
if training_args.fp8 and hasattr(self, "accelerator"): # verify FP8 status after trainer initialization
|
||||||
verify_fp8_status(self.accelerator, training_args)
|
verify_fp8_status(self.accelerator, training_args)
|
||||||
|
|
||||||
|
|||||||
@@ -634,7 +634,9 @@ def get_batch_logps(
|
|||||||
return logps, valid_length
|
return logps, valid_length
|
||||||
|
|
||||||
|
|
||||||
def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
def dft_loss_func(
|
||||||
|
outputs: "torch.Tensor", labels: "torch.Tensor", num_items_in_batch: Optional["torch.Tensor"] = None
|
||||||
|
):
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
if logits is None:
|
if logits is None:
|
||||||
return outputs.get("loss", torch.tensor(0.0))
|
return outputs.get("loss", torch.tensor(0.0))
|
||||||
@@ -652,11 +654,11 @@ def dft_loss_func(outputs, labels, num_items_in_batch=None):
|
|||||||
|
|
||||||
|
|
||||||
def _dft_cross_entropy(
|
def _dft_cross_entropy(
|
||||||
source: torch.Tensor,
|
source: "torch.Tensor",
|
||||||
target: torch.Tensor,
|
target: "torch.Tensor",
|
||||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = -100,
|
||||||
) -> torch.Tensor:
|
) -> "torch.Tensor":
|
||||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||||
valid_mask = target != ignore_index
|
valid_mask = target != ignore_index
|
||||||
if not valid_mask.any():
|
if not valid_mask.any():
|
||||||
@@ -679,7 +681,12 @@ def _dft_cross_entropy(
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
|
def eaft_loss_func(
|
||||||
|
outputs: "torch.Tensor",
|
||||||
|
labels: "torch.Tensor",
|
||||||
|
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||||
|
alpha: float = 1.0,
|
||||||
|
) -> "torch.Tensor":
|
||||||
logits = outputs.get("logits")
|
logits = outputs.get("logits")
|
||||||
if logits is None:
|
if logits is None:
|
||||||
return outputs.get("loss", torch.tensor(0.0))
|
return outputs.get("loss", torch.tensor(0.0))
|
||||||
@@ -697,12 +704,12 @@ def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
|
|||||||
|
|
||||||
|
|
||||||
def _eaft_cross_entropy(
|
def _eaft_cross_entropy(
|
||||||
source: torch.Tensor,
|
source: "torch.Tensor",
|
||||||
target: torch.Tensor,
|
target: "torch.Tensor",
|
||||||
num_items_in_batch: Optional[torch.Tensor] = None,
|
num_items_in_batch: Optional["torch.Tensor"] = None,
|
||||||
alpha: float = 1.0,
|
alpha: float = 1.0,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = -100,
|
||||||
) -> torch.Tensor:
|
) -> "torch.Tensor":
|
||||||
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
per_token_loss = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
|
||||||
valid_mask = target != ignore_index
|
valid_mask = target != ignore_index
|
||||||
if not valid_mask.any():
|
if not valid_mask.any():
|
||||||
@@ -712,13 +719,13 @@ def _eaft_cross_entropy(
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
source_detached = source[valid_mask].detach()
|
source_detached = source[valid_mask].detach()
|
||||||
|
|
||||||
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
|
topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
|
||||||
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
|
logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
|
||||||
log_probs_topk = topk_val - logsumexp_topk
|
log_probs_topk = topk_val - logsumexp_topk
|
||||||
probs_topk = torch.exp(log_probs_topk)
|
probs_topk = torch.exp(log_probs_topk)
|
||||||
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
|
entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)
|
||||||
|
|
||||||
entropy_term = entropy_approx / 3.0
|
entropy_term = entropy_approx / 3.0
|
||||||
adaptive_weight = torch.pow(entropy_term, alpha)
|
adaptive_weight = torch.pow(entropy_term, alpha)
|
||||||
|
|
||||||
@@ -731,6 +738,7 @@ def _eaft_cross_entropy(
|
|||||||
loss = total_loss / num_items_in_batch
|
loss = total_loss / num_items_in_batch
|
||||||
else:
|
else:
|
||||||
loss = weighted_losses.mean()
|
loss = weighted_losses.mean()
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ class ModelArguments:
|
|||||||
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
metadata={"help": "Path to the model or model identifier from Hugging Face."},
|
||||||
)
|
)
|
||||||
template: str = field(
|
template: str = field(
|
||||||
default="chatml",
|
default="qwen3_nothink",
|
||||||
metadata={"help": "Template for the model."},
|
metadata={"help": "Template for the model."},
|
||||||
)
|
)
|
||||||
trust_remote_code: bool = field(
|
trust_remote_code: bool = field(
|
||||||
|
|||||||
@@ -12,38 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
|
|
||||||
from ...utils.constants import IGNORE_INDEX
|
from ...utils.constants import IGNORE_INDEX
|
||||||
from ...utils.helper import get_tokenizer
|
from ...utils.helper import get_tokenizer
|
||||||
from ...utils.types import Message, ModelInput, Processor
|
from ...utils.types import Message, ModelInput, Processor
|
||||||
|
|
||||||
|
|
||||||
def _update_model_input(
|
|
||||||
processor: Processor,
|
|
||||||
input_ids: list[int],
|
|
||||||
labels: list[int],
|
|
||||||
loss_weights: list[int],
|
|
||||||
temp_str: str,
|
|
||||||
temp_weight: float,
|
|
||||||
) -> str:
|
|
||||||
"""Update model input with temporary string."""
|
|
||||||
if not temp_str:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
tokenizer = get_tokenizer(processor)
|
|
||||||
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
|
||||||
input_ids.extend(temp_ids)
|
|
||||||
loss_weights.extend([temp_weight] * len(temp_ids))
|
|
||||||
if temp_weight > 1e-6:
|
|
||||||
labels.extend(temp_ids)
|
|
||||||
else:
|
|
||||||
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
|
||||||
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def render_chatml_messages(
|
def render_chatml_messages(
|
||||||
processor: Processor,
|
processor: Processor,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
@@ -52,123 +26,38 @@ def render_chatml_messages(
|
|||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
"""Apply chatml template to messages and convert them to model input.
|
"""Apply chatml template to messages and convert them to model input.
|
||||||
|
|
||||||
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground
|
See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen2-7B-Instruct
|
||||||
"""
|
"""
|
||||||
|
tokenizer = get_tokenizer(processor)
|
||||||
input_ids, labels, loss_weights = [], [], []
|
input_ids, labels, loss_weights = [], [], []
|
||||||
temp_str, temp_weight = "", 0.0
|
|
||||||
if tools:
|
|
||||||
temp_str += "<|im_start|>system\n"
|
|
||||||
if messages[0]["role"] == "system":
|
|
||||||
for content in messages[0]["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "\n\n"
|
for message in messages:
|
||||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
temp_str = "<|im_start|>" + message["role"] + "\n"
|
||||||
|
for content in message["content"]:
|
||||||
temp_str += (
|
|
||||||
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
|
||||||
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
tools = json.loads(tools)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
|
||||||
|
|
||||||
if not isinstance(tools, list):
|
|
||||||
tools = [tools]
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
|
||||||
|
|
||||||
temp_str += (
|
|
||||||
"\n</tools>\n\nFor each function call, return a json object with function name "
|
|
||||||
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
|
||||||
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
|
||||||
)
|
|
||||||
elif messages[0]["role"] == "system":
|
|
||||||
temp_str += "<|im_start|>system\n"
|
|
||||||
for content in messages[0]["content"]:
|
|
||||||
if content["type"] == "text":
|
if content["type"] == "text":
|
||||||
temp_str += content["value"]
|
temp_str += content["value"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
temp_str += "<|im_end|>\n"
|
||||||
temp_weight = messages[0].get("loss_weight", 0.0)
|
temp_weight = message.get("loss_weight", 1.0 if message["role"] == "assistant" else 0.0)
|
||||||
|
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||||
for turn_idx, message in enumerate(messages):
|
if temp_weight > 1e-6:
|
||||||
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
labels.extend(temp_ids)
|
||||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
else:
|
||||||
for content in message["content"]:
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = message.get("loss_weight", 0.0)
|
|
||||||
elif message["role"] == "assistant":
|
|
||||||
temp_str += "<|im_start|>" + message["role"] + "\n"
|
|
||||||
for val_idx, content in enumerate(message["content"]):
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
elif content["type"] == "reasoning":
|
|
||||||
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
|
||||||
elif content["type"] == "tool_call":
|
|
||||||
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
|
||||||
temp_str += "\n"
|
|
||||||
|
|
||||||
try:
|
|
||||||
tool_call = json.loads(content["value"])
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
|
||||||
temp_str += (
|
|
||||||
'<tool_call>\n{"name": "'
|
|
||||||
+ tool_call["name"]
|
|
||||||
+ '", "arguments": '
|
|
||||||
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
|
||||||
+ "}\n</tool_call>"
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
temp_weight = message.get("loss_weight", 1.0)
|
|
||||||
elif message["role"] == "tool":
|
|
||||||
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
|
||||||
temp_str += "<|im_start|>user"
|
|
||||||
|
|
||||||
temp_str += "\n<tool_response>\n"
|
|
||||||
for content in message["content"]:
|
|
||||||
if content["type"] == "text":
|
|
||||||
temp_str += content["value"]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported content type: {content['type']}")
|
|
||||||
|
|
||||||
temp_str += "\n</tool_response>"
|
|
||||||
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
|
||||||
temp_str += "<|im_end|>\n"
|
|
||||||
|
|
||||||
temp_weight = message.get("loss_weight", 0.0)
|
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
if is_generate:
|
if is_generate:
|
||||||
temp_str += "<|im_start|>assistant\n"
|
temp_ids = tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False)
|
||||||
temp_weight = 0.0
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([0.0] * len(temp_ids))
|
||||||
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
|
|
||||||
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
|
||||||
|
|
||||||
attention_mask = [1] * len(input_ids)
|
|
||||||
return ModelInput(
|
return ModelInput(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=[1] * len(input_ids),
|
||||||
labels=labels,
|
labels=labels,
|
||||||
loss_weights=loss_weights,
|
loss_weights=loss_weights,
|
||||||
)
|
)
|
||||||
@@ -183,36 +72,7 @@ def parse_chatml_message(generated_text: str) -> Message:
|
|||||||
Returns:
|
Returns:
|
||||||
Message: The parsed message.
|
Message: The parsed message.
|
||||||
"""
|
"""
|
||||||
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
return Message(role="assistant", content=[{"type": "text", "value": generated_text}])
|
||||||
content = []
|
|
||||||
last_end = 0
|
|
||||||
for match in pattern.finditer(generated_text):
|
|
||||||
start, end = match.span()
|
|
||||||
if start > last_end:
|
|
||||||
text = generated_text[last_end:start].strip()
|
|
||||||
if text:
|
|
||||||
content.append({"type": "text", "value": text})
|
|
||||||
|
|
||||||
tag_type = match.group(1)
|
|
||||||
tag_value = match.group(2).strip()
|
|
||||||
if tag_type == "thinking":
|
|
||||||
content.append({"type": "reasoning", "value": tag_value.strip()})
|
|
||||||
elif tag_type == "tool_call":
|
|
||||||
try:
|
|
||||||
json.loads(tag_value.strip())
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
|
||||||
|
|
||||||
content.append({"type": "tool_call", "value": tag_value.strip()})
|
|
||||||
|
|
||||||
last_end = end
|
|
||||||
|
|
||||||
if last_end < len(generated_text):
|
|
||||||
text = generated_text[last_end:].strip()
|
|
||||||
if text:
|
|
||||||
content.append({"type": "text", "value": text})
|
|
||||||
|
|
||||||
return Message(role="assistant", content=content)
|
|
||||||
|
|
||||||
|
|
||||||
class Renderer:
|
class Renderer:
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
return {"messages": messages, "extra_info": json.dumps({"tools": tools})}
|
return {"messages": messages, "tools": json.dumps(tools)}
|
||||||
else:
|
else:
|
||||||
return {"messages": messages}
|
return {"messages": messages}
|
||||||
|
|
||||||
|
|||||||
@@ -13,24 +13,200 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from ...utils.constants import IGNORE_INDEX
|
||||||
|
from ...utils.helper import get_tokenizer
|
||||||
from ...utils.plugin import BasePlugin
|
from ...utils.plugin import BasePlugin
|
||||||
from ...utils.types import Message, ModelInput, Processor
|
from ...utils.types import Message, ModelInput, Processor, ToolCall
|
||||||
|
|
||||||
|
|
||||||
class RenderingPlugin(BasePlugin):
|
class RenderingPlugin(BasePlugin):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen").register("render_messages")
|
def _update_model_input(
|
||||||
|
processor: Processor,
|
||||||
|
input_ids: list[int],
|
||||||
|
labels: list[int],
|
||||||
|
loss_weights: list[int],
|
||||||
|
temp_str: str,
|
||||||
|
temp_weight: float,
|
||||||
|
) -> str:
|
||||||
|
"""Update model input with temporary string."""
|
||||||
|
if not temp_str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
tokenizer = get_tokenizer(processor)
|
||||||
|
temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
|
||||||
|
input_ids.extend(temp_ids)
|
||||||
|
loss_weights.extend([temp_weight] * len(temp_ids))
|
||||||
|
if temp_weight > 1e-6:
|
||||||
|
labels.extend(temp_ids)
|
||||||
|
else:
|
||||||
|
labels.extend([IGNORE_INDEX] * len(temp_ids))
|
||||||
|
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@RenderingPlugin("qwen3_nothink").register("render_messages")
|
||||||
def render_qwen_messages(
|
def render_qwen_messages(
|
||||||
processor: Processor,
|
processor: Processor,
|
||||||
messages: list[Message],
|
messages: list[Message],
|
||||||
tools: str | None = None,
|
tools: str | None = None,
|
||||||
is_generate: bool = False,
|
is_generate: bool = False,
|
||||||
) -> ModelInput:
|
) -> ModelInput:
|
||||||
raise NotImplementedError()
|
input_ids, labels, loss_weights = [], [], []
|
||||||
|
temp_str, temp_weight = "", 0.0
|
||||||
|
if tools:
|
||||||
|
temp_str += "<|im_start|>system\n"
|
||||||
|
if messages[0]["role"] == "system":
|
||||||
|
for content in messages[0]["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "\n\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
|
||||||
|
"You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
tools = json.loads(tools)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tools format: {str(tools)}.")
|
||||||
|
|
||||||
|
if not isinstance(tools, list):
|
||||||
|
tools = [tools]
|
||||||
|
|
||||||
|
for tool in tools:
|
||||||
|
temp_str += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
"\n</tools>\n\nFor each function call, return a json object with function name "
|
||||||
|
'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
|
||||||
|
'<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
|
||||||
|
)
|
||||||
|
elif messages[0]["role"] == "system":
|
||||||
|
temp_str += "<|im_start|>system\n"
|
||||||
|
for content in messages[0]["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = messages[0].get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
for turn_idx, message in enumerate(messages):
|
||||||
|
if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
temp_str += "<|im_start|>" + message["role"] + "\n"
|
||||||
|
for val_idx, content in enumerate(message["content"]):
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
elif content["type"] == "reasoning":
|
||||||
|
temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n" # avoid using special tokens
|
||||||
|
elif content["type"] == "tool_call":
|
||||||
|
if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
|
||||||
|
temp_str += "\n"
|
||||||
|
|
||||||
|
try:
|
||||||
|
tool_call: ToolCall = json.loads(content["value"])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {content['value']}.")
|
||||||
|
|
||||||
|
temp_str += (
|
||||||
|
'<tool_call>\n{"name": "'
|
||||||
|
+ tool_call["name"]
|
||||||
|
+ '", "arguments": '
|
||||||
|
+ json.dumps(tool_call["arguments"], ensure_ascii=False)
|
||||||
|
+ "}\n</tool_call>"
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
temp_weight = message.get("loss_weight", 1.0)
|
||||||
|
elif message["role"] == "tool":
|
||||||
|
if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_start|>user"
|
||||||
|
|
||||||
|
temp_str += "\n<tool_response>\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
temp_str += content["value"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported content type: {content['type']}")
|
||||||
|
|
||||||
|
temp_str += "\n</tool_response>"
|
||||||
|
if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
|
||||||
|
temp_str += "<|im_end|>\n"
|
||||||
|
|
||||||
|
temp_weight = message.get("loss_weight", 0.0)
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
if is_generate:
|
||||||
|
temp_str += "<|im_start|>assistant\n"
|
||||||
|
temp_weight = 0.0
|
||||||
|
|
||||||
|
temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)
|
||||||
|
|
||||||
|
attention_mask = [1] * len(input_ids)
|
||||||
|
return ModelInput(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
labels=labels,
|
||||||
|
loss_weights=loss_weights,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@RenderingPlugin("qwen").register("parse_message")
|
@RenderingPlugin("qwen3_nothink").register("parse_message")
|
||||||
def parse_qwen_message(generated_text: str) -> Message:
|
def parse_qwen_message(generated_text: str) -> Message:
|
||||||
raise NotImplementedError()
|
pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
|
||||||
|
content = []
|
||||||
|
last_end = 0
|
||||||
|
for match in pattern.finditer(generated_text):
|
||||||
|
start, end = match.span()
|
||||||
|
if start > last_end:
|
||||||
|
text = generated_text[last_end:start].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
tag_type = match.group(1)
|
||||||
|
tag_value = match.group(2).strip()
|
||||||
|
if tag_type == "thinking":
|
||||||
|
content.append({"type": "reasoning", "value": tag_value.strip()})
|
||||||
|
elif tag_type == "tool_call":
|
||||||
|
try:
|
||||||
|
json.loads(tag_value.strip())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError(f"Invalid tool call format: {tag_value.strip()}.")
|
||||||
|
|
||||||
|
content.append({"type": "tool_call", "value": tag_value.strip()})
|
||||||
|
|
||||||
|
last_end = end
|
||||||
|
|
||||||
|
if last_end < len(generated_text):
|
||||||
|
text = generated_text[last_end:].strip()
|
||||||
|
if text:
|
||||||
|
content.append({"type": "text", "value": text})
|
||||||
|
|
||||||
|
return Message(role="assistant", content=content)
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
|
from typing import TYPE_CHECKING, Any, Literal, NotRequired, TypedDict, Union
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -102,8 +102,10 @@ class Message(TypedDict):
|
|||||||
class SFTSample(TypedDict):
|
class SFTSample(TypedDict):
|
||||||
messages: list[Message]
|
messages: list[Message]
|
||||||
"""Messages in the sample."""
|
"""Messages in the sample."""
|
||||||
|
tools: NotRequired[str]
|
||||||
|
"""Tools for the sample in JSON string format."""
|
||||||
extra_info: NotRequired[str]
|
extra_info: NotRequired[str]
|
||||||
"""Extra information for the sample, including tools, kto_labels."""
|
"""Extra information for the sample, e.g. kto_labels."""
|
||||||
_dataset_name: NotRequired[str]
|
_dataset_name: NotRequired[str]
|
||||||
"""Dataset name for the sample."""
|
"""Dataset name for the sample."""
|
||||||
|
|
||||||
@@ -113,8 +115,10 @@ class DPOSample(TypedDict):
|
|||||||
"""Chosen messages in the sample."""
|
"""Chosen messages in the sample."""
|
||||||
rejected_messages: list[Message]
|
rejected_messages: list[Message]
|
||||||
"""Rejected messages in the sample."""
|
"""Rejected messages in the sample."""
|
||||||
|
tools: NotRequired[str]
|
||||||
|
"""Tools for the sample in JSON string format."""
|
||||||
extra_info: NotRequired[str]
|
extra_info: NotRequired[str]
|
||||||
"""Extra information for the sample, including tools, kto_labels."""
|
"""Extra information for the sample, e.g. kto_labels."""
|
||||||
_dataset_name: NotRequired[str]
|
_dataset_name: NotRequired[str]
|
||||||
"""Dataset name for the sample."""
|
"""Dataset name for the sample."""
|
||||||
|
|
||||||
@@ -125,7 +129,7 @@ Sample = Union[SFTSample, DPOSample]
|
|||||||
class ToolCall(TypedDict):
|
class ToolCall(TypedDict):
|
||||||
name: str
|
name: str
|
||||||
"""Function name."""
|
"""Function name."""
|
||||||
arguments: str
|
arguments: dict[str, Any]
|
||||||
"""Function arguments."""
|
"""Function arguments."""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,2 @@
|
|||||||
# change if test fails or cache is outdated
|
# change if test fails or cache is outdated
|
||||||
0.9.5.101
|
0.9.5.103
|
||||||
|
|||||||
@@ -12,8 +12,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import pytest
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
from llamafactory.v1.config import DataArguments
|
||||||
|
from llamafactory.v1.core.data_engine import DataEngine
|
||||||
from llamafactory.v1.core.utils.rendering import Renderer
|
from llamafactory.v1.core.utils.rendering import Renderer
|
||||||
from llamafactory.v1.utils.types import Processor
|
from llamafactory.v1.utils.types import Processor
|
||||||
|
|
||||||
@@ -23,12 +28,54 @@ HF_MESSAGES = [
|
|||||||
{"role": "user", "content": "What is LLM?"},
|
{"role": "user", "content": "What is LLM?"},
|
||||||
{"role": "assistant", "content": "LLM stands for Large Language Model."},
|
{"role": "assistant", "content": "LLM stands for Large Language Model."},
|
||||||
]
|
]
|
||||||
|
|
||||||
V1_MESSAGES = [
|
V1_MESSAGES = [
|
||||||
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
|
{"role": "system", "content": [{"type": "text", "value": "You are a helpful assistant."}]},
|
||||||
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
|
{"role": "user", "content": [{"type": "text", "value": "What is LLM?"}]},
|
||||||
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
|
{"role": "assistant", "content": [{"type": "text", "value": "LLM stands for Large Language Model."}]},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Tool-use conversation written as plain role/content dicts with a
# ``tool_calls`` entry on the assistant turn.
HF_MESSAGES_WITH_TOOLS = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 6*8?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "type": "function",
                "function": {"name": "multiply", "arguments": {"a": 6, "b": 8}},
            }
        ],
    },
    {"role": "tool", "content": "48."},
    {"role": "assistant", "content": "The result of 6*8 is 48."},
]
|
||||||
|
|
||||||
|
# The same tool-use conversation with every ``content`` given as a list of
# typed parts; the tool call is carried as a JSON string in a ``tool_call`` part.
V1_MESSAGES_WITH_TOOLS = [
    {
        "role": "system",
        "content": [{"type": "text", "value": "You are a helpful assistant."}],
    },
    {
        "role": "user",
        "content": [{"type": "text", "value": "What is 6*8?"}],
    },
    {
        "role": "assistant",
        "content": [
            {
                "type": "tool_call",
                "value": json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}}),
            }
        ],
        "loss_weight": 0.0,
    },
    {
        "role": "tool",
        "content": [{"type": "text", "value": "48."}],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "value": "The result of 6*8 is 48."}],
    },
]
|
||||||
|
|
||||||
|
# Schema of the single ``multiply`` tool offered to the model in the tool-use tests.
V1_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "multiply",
            "description": "A function that multiplies two numbers",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {
                        "type": "number",
                        "description": "The first number to multiply",
                    },
                    "b": {
                        "type": "number",
                        "description": "The second number to multiply",
                    },
                },
                "required": ["a", "b"],
            },
        },
    }
]
|
||||||
|
|
||||||
|
|
||||||
def test_chatml_rendering():
|
def test_chatml_rendering():
|
||||||
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
|
tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
|
||||||
@@ -60,6 +107,87 @@ def test_chatml_parse():
|
|||||||
assert parsed_message == V1_MESSAGES[-1]
|
assert parsed_message == V1_MESSAGES[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_samples", [16])
def test_chatml_rendering_remote(num_samples: int):
    r"""Render dataset samples with the ``chatml`` template and check their leading tokens.

    The first ``num_samples`` samples of ``llamafactory/v1-sft-demo`` are rendered
    in generation mode; each result must begin with the token ids of
    ``<|im_start|>user\n``.

    Args:
        num_samples: Number of leading dataset samples to render and check.
    """
    tokenizer: Processor = AutoTokenizer.from_pretrained("llamafactory/tiny-random-qwen3")
    renderer = Renderer(template="chatml", processor=tokenizer)
    data_args = DataArguments(dataset="llamafactory/v1-sft-demo")
    data_engine = DataEngine(data_args)

    # The expected prefix is identical for every sample, so tokenize it once
    # instead of on each loop iteration.
    prefix = tokenizer.encode("<|im_start|>user\n", add_special_tokens=False)

    for index in range(num_samples):
        v1_inputs = renderer.render_messages(data_engine[index]["messages"], is_generate=True)
        rendered_prefix = v1_inputs["input_ids"][: len(prefix)]
        # Printed so the decoded prefix shows up in pytest's captured output on failure.
        print(tokenizer.decode(rendered_prefix))
        assert rendered_prefix == prefix
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen3_nothink_rendering():
    """Compare ``qwen3_nothink`` rendering against the tokenizer's own chat template.

    Two cases are checked on the tool-use conversation:

    * generation mode (final assistant reply dropped): token ids match the
      chat template output, and every label is ``-100`` with loss weight ``0.0``;
    * full conversation: token ids match, and only the tokens past the
      conversation-without-final-reply carry labels and loss weight ``1.0``.
    """
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)

    # --- generation mode: the prompt ends with a generation prompt ---
    expected_prompt = tokenizer.apply_chat_template(
        HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=True
    )
    rendered_prompt = renderer.render_messages(
        V1_MESSAGES_WITH_TOOLS[:-1], tools=json.dumps(V1_TOOLS), is_generate=True
    )
    prompt_len = len(expected_prompt)
    assert rendered_prompt["input_ids"] == expected_prompt
    assert rendered_prompt["attention_mask"] == [1] * prompt_len
    assert rendered_prompt["labels"] == [-100] * prompt_len
    assert rendered_prompt["loss_weights"] == [0.0] * prompt_len

    # --- full conversation: only the tail past the context is supervised ---
    expected_context = tokenizer.apply_chat_template(
        HF_MESSAGES_WITH_TOOLS[:-1], tools=V1_TOOLS, add_generation_prompt=False
    )
    expected_full = tokenizer.apply_chat_template(
        HF_MESSAGES_WITH_TOOLS, tools=V1_TOOLS, add_generation_prompt=False
    )
    rendered_full = renderer.render_messages(
        V1_MESSAGES_WITH_TOOLS, tools=json.dumps(V1_TOOLS), is_generate=False
    )
    context_len = len(expected_context)
    full_len = len(expected_full)
    assert rendered_full["input_ids"] == expected_full
    assert rendered_full["attention_mask"] == [1] * full_len
    assert rendered_full["labels"] == [-100] * context_len + expected_full[context_len:]
    assert rendered_full["loss_weights"] == [0.0] * context_len + [1.0] * (full_len - context_len)
|
||||||
|
|
||||||
|
|
||||||
|
def test_qwen3_nothink_parse():
    """Parse generated text holding reasoning, plain text and a tool call.

    The parsed assistant message is expected to contain three content parts,
    in order: ``reasoning``, ``text`` and ``tool_call`` (the latter as a JSON string).
    """
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)

    generated_text = (
        "<thinking>I need to use the multiply function to calculate 6*8.</thinking>"
        "Let me call the multiply function."
        '<tool_call>{"name": "multiply", "arguments": {"a": 6, "b": 8}}</tool_call>'
    )

    tool_call_json = json.dumps({"name": "multiply", "arguments": {"a": 6, "b": 8}})
    expected_message = {
        "role": "assistant",
        "content": [
            {"type": "reasoning", "value": "I need to use the multiply function to calculate 6*8."},
            {"type": "text", "value": "Let me call the multiply function."},
            {"type": "tool_call", "value": tool_call_json},
        ],
    }

    parsed_message = renderer.parse_message(generated_text)
    assert parsed_message == expected_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("num_samples", [8])
def test_qwen3_nothink_rendering_remote(num_samples: int):
    """Render dataset samples with the ``qwen3_nothink`` template and check their leading tokens.

    The first ``num_samples`` samples of ``llamafactory/reason-tool-use-demo-1500``
    are rendered together with their tools; each result must begin with the token
    ids of the expected system prompt and tool-section header.

    Args:
        num_samples: Number of leading dataset samples to render and check.
    """
    tokenizer: Processor = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct-2507")
    renderer = Renderer(template="qwen3_nothink", processor=tokenizer)
    data_args = DataArguments(dataset="llamafactory/reason-tool-use-demo-1500")
    data_engine = DataEngine(data_args)

    # The expected prefix does not depend on the sample, so build and tokenize
    # it once instead of on each loop iteration.
    prefix_text = (
        "<|im_start|>system\nYou are a methodical and expert assistant. "
        "Your primary goal is to solve user requests by leveraging a set of available tools. "
        "You must reason for the best course of action in a structured manner before responding.\n\n"
        "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
        "You are provided with function signatures within <tools></tools> XML tags:\n<tools>\n"
        '{"type": "function", "function": {"name":'
    )
    prefix = tokenizer.encode(prefix_text, add_special_tokens=False)

    for index in range(num_samples):
        # Fetch the sample once rather than indexing the engine per field.
        sample = data_engine[index]
        v1_inputs = renderer.render_messages(sample["messages"], tools=sample["tools"])
        rendered_prefix = v1_inputs["input_ids"][: len(prefix)]
        # Printed so the decoded prefix shows up in pytest's captured output on failure.
        print(tokenizer.decode(rendered_prefix))
        assert rendered_prefix == prefix
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Script entry point: run every test in this module directly.
    # Sample counts passed to the remote tests match their pytest parametrization.
    test_chatml_rendering()
    test_chatml_parse()
    test_chatml_rendering_remote(16)
    test_qwen3_nothink_rendering()
    test_qwen3_nothink_parse()
    test_qwen3_nothink_rendering_remote(8)
|
||||||
|
|||||||
@@ -61,11 +61,11 @@ def test_sharegpt_converter():
|
|||||||
}
|
}
|
||||||
expected_data = {
|
expected_data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{"content": [{"type": "text", "value": "System"}], "loss_weight": 0.0, "role": "system"},
|
{"role": "system", "content": [{"type": "text", "value": "System"}], "loss_weight": 0.0},
|
||||||
{"content": [{"type": "text", "value": "User"}], "loss_weight": 0.0, "role": "user"},
|
{"role": "user", "content": [{"type": "text", "value": "User"}], "loss_weight": 0.0},
|
||||||
{"content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0, "role": "assistant"},
|
{"role": "assistant", "content": [{"type": "tool_call", "value": "1"}], "loss_weight": 1.0},
|
||||||
{"content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0, "role": "tool"},
|
{"role": "tool", "content": [{"type": "text", "value": "Observation"}], "loss_weight": 0.0},
|
||||||
{"content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0, "role": "assistant"},
|
{"role": "assistant", "content": [{"type": "text", "value": "Assistant"}], "loss_weight": 1.0},
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
assert DataConverterPlugin("sharegpt")(example) == expected_data
|
assert DataConverterPlugin("sharegpt")(example) == expected_data
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from llamafactory.v1.samplers.cli_sampler import SyncSampler
|
|||||||
|
|
||||||
@pytest.mark.runs_on(["cuda", "npu"])
|
@pytest.mark.runs_on(["cuda", "npu"])
|
||||||
def test_sync_sampler():
|
def test_sync_sampler():
|
||||||
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507")
|
model_args = ModelArguments(model="Qwen/Qwen3-4B-Instruct-2507", template="qwen3_nothink")
|
||||||
sample_args = SampleArguments()
|
sample_args = SampleArguments()
|
||||||
model_engine = ModelEngine(model_args)
|
model_engine = ModelEngine(model_args)
|
||||||
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
|
sampler = SyncSampler(sample_args, model_args, model_engine.model, model_engine.renderer)
|
||||||
|
|||||||
Reference in New Issue
Block a user