3 Commits

Author SHA1 Message Date
jiaqiw09
9a0cfdccfa [v1] fix init on meta in transformers v5 (#10414) 2026-04-27 00:37:09 +08:00
Kingsley
c8890c32db [data] support discard history cot for multiturn (#10435) 2026-04-27 00:32:44 +08:00
Kingsley
79c8332e4c [train] add qwen35 patch for neat_packing (#10436) 2026-04-27 00:31:49 +08:00
7 changed files with 492 additions and 37 deletions

View File

@@ -471,8 +471,8 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
def __post_init__(self): def __post_init__(self):
super().__post_init__() super().__post_init__()
if self.neat_packing and self.attn_implementation == "flash_attention_2": if self.neat_packing and self.attn_implementation == "flash_attention_2":
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]: if self.model is not None and getattr(self.model.config, "model_type", None) in ["gemma4", "gpt_oss"]:
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.") raise ValueError("Neat packing is not supported for gemma4, gpt_oss models for now.")
@staticmethod @staticmethod
def _unpad_packed_features(features: dict[str, Any]) -> None: def _unpad_packed_features(features: dict[str, Any]) -> None:

View File

@@ -61,7 +61,8 @@ class SupervisedDatasetProcessor(DatasetProcessor):
input_ids, labels = self.template.mm_plugin.process_token_ids( input_ids, labels = self.template.mm_plugin.process_token_ids(
[], [], images, videos, audios, self.tokenizer, self.processor [], [], images, videos, audios, self.tokenizer, self.processor
) )
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools) discarding_history_cot = self.data_args.mask_history and not self.template.preserve_thinking
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools, discarding_history_cot)
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0) total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
if self.data_args.mask_history: if self.data_args.mask_history:
encoded_pairs = encoded_pairs[::-1] # high priority for last turns encoded_pairs = encoded_pairs[::-1] # high priority for last turns

View File

@@ -79,6 +79,7 @@ class Template:
messages: list[dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
discarding_history_cot: bool = False, # only effect reasoning template
) -> list[tuple[list[int], list[int]]]: ) -> list[tuple[list[int], list[int]]]:
r"""Return multiple pairs of token ids representing prompts and responses respectively.""" r"""Return multiple pairs of token ids representing prompts and responses respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools) encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -441,14 +442,24 @@ class ReasoningTemplate(Template):
messages: list[dict[str, str]], messages: list[dict[str, str]],
system: Optional[str] = None, system: Optional[str] = None,
tools: Optional[str] = None, tools: Optional[str] = None,
discarding_history_cot: bool = False,
) -> list[tuple[list[int], list[int]]]: ) -> list[tuple[list[int], list[int]]]:
messages = deepcopy(messages) messages = deepcopy(messages)
if self.enable_thinking is False: # remove all cot if self.enable_thinking is False: # remove all cot
for i in range(1, len(messages), 2): for i in range(1, len(messages), 2):
messages[i]["content"] = self.remove_thought(messages[i]["content"]) messages[i]["content"] = self.remove_thought(messages[i]["content"])
if discarding_history_cot:
for i in range(1, len(messages) - 2, 2): # preserve the last cot
messages[i]["content"] = self.remove_thought(messages[i]["content"])
encoded_messages = self._encode(tokenizer, messages, system, tools) encoded_messages = self._encode(tokenizer, messages, system, tools)
for i in range(0, len(messages), 2): if discarding_history_cot:
turn_indices = [len(messages) - 2]
else:
turn_indices = range(0, len(messages), 2)
for i in turn_indices:
if ( if (
self.thought_words[0].strip() not in messages[i + 1]["content"] self.thought_words[0].strip() not in messages[i + 1]["content"]
and self.thought_words[1].strip() not in messages[i + 1]["content"] and self.thought_words[1].strip() not in messages[i + 1]["content"]
@@ -2135,23 +2146,6 @@ register_template(
) )
# copied from qwen3_5_nothink template
register_template(
name="qwen3_6_nothink",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen3_5"),
stop_words=["<|im_end|>"],
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
register_template( register_template(
name="sailor", name="sailor",
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]), format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),

View File

@@ -60,6 +60,191 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
def _check_fla_dependencies() -> None:
"""Check that the FLA dependencies required for varlen GDN forwarding are available.
Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
``causal_conv1d`` under ``fla.modules.convolution`` and the
``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
actionable message otherwise.
"""
try:
from fla.modules.convolution import causal_conv1d # noqa: F401
from fla.ops.gated_delta_rule import ( # noqa: F401
chunk_gated_delta_rule,
fused_recurrent_gated_delta_rule,
)
except ImportError as exc:
raise ImportError(
"Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
"(provides `fla.modules.convolution.causal_conv1d` and "
"`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
"Please install/upgrade it."
) from exc
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
    """Monkey-patch the Qwen3.5 decoder and GatedDeltaNet forwards for packed (cu_seqlens) training.

    The replacement decoder forward passes ``position_ids`` to the linear-attention
    layer. The replacement GatedDeltaNet forward derives ``cu_seqlens`` from those
    ``position_ids`` and passes them to the FLA varlen kernels. The replacements are
    assigned on the layer classes themselves, not on ``model``. The patched
    GatedDeltaNet forward is meant for training only: it never reads ``cache_params``.

    Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.

    Args:
        model: Model whose ``config.architectures[0]`` selects which classes are patched
            (``Qwen3_5ForConditionalGeneration`` or ``Qwen3_5MoeForConditionalGeneration``).

    Raises:
        ImportError: Propagated from ``_check_fla_dependencies`` when the FLA kernels are unavailable.
    """
    # NOTE(review): assuming only this import is guarded by the version check, the name
    # `apply_mask_to_padding_states` is unbound on older transformers although `_patch_gdn_forward`
    # uses it unconditionally — presumably Qwen3.5 cannot be loaded there anyway; confirm.
    if is_transformers_version_greater_than("5.2.0"):
        from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states

    from torch.nn import functional as F
    from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids

    # Fail early with an actionable message before importing the FLA kernels below.
    _check_fla_dependencies()
    from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
    from fla.ops.gated_delta_rule import chunk_gated_delta_rule

    def _patched_decoder_forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values=None,
        cache_position: torch.LongTensor | None = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Decoder layer forward that passes position_ids through to linear attention."""
        # Token-mixing sub-layer with a residual connection around it.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        if self.layer_type == "linear_attention":
            # Extra **kwargs are not forwarded on this path; only the named arguments are.
            hidden_states = self.linear_attn(
                hidden_states=hidden_states,
                cache_params=past_key_values,
                cache_position=cache_position,
                attention_mask=attention_mask,
                position_ids=position_ids,  # passing position_ids to linear attention
            )
        elif self.layer_type == "full_attention":
            # The second element returned by self_attn is discarded.
            hidden_states, _ = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        # Any other layer_type falls through with only the layer norm applied.

        hidden_states = residual + hidden_states

        # Feed-forward sub-layer with its own residual connection.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        if isinstance(hidden_states, tuple):  # MoE returns (hidden_states, router_logits)
            hidden_states, _ = hidden_states
        hidden_states = residual + hidden_states
        return hidden_states

    # gdn forward (training only, cache_params is always None)
    def _patch_gdn_forward(
        self,
        hidden_states: torch.Tensor,
        cache_params=None,
        cache_position: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
    ):
        # `cache_params` and `cache_position` are accepted for signature compatibility but never read here.
        # @kuangdd fix: here attention_mask is None
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        batch_size, seq_len, _ = hidden_states.shape

        # Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T].
        if position_ids is not None and position_ids.ndim == 3:
            position_ids = position_ids[0]

        # `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
        # NOTE(review): `[0][0]` presumably selects the query-side cu_seqlens — confirm against transformers.
        cu_seqlens = (
            prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
            if position_ids is not None
            else None
        )

        # FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
        # standard causal-conv1d path that the upstream forward uses.
        mixed_qkv = self.in_proj_qkv(hidden_states)
        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)

        # FLA's causal_conv1d returns (out, final_state); we don't use the state here.
        mixed_qkv, _ = fla_causal_conv1d(
            x=mixed_qkv,
            weight=self.conv1d.weight.squeeze(1),
            bias=self.conv1d.bias,
            activation=self.activation,
            cu_seqlens=cu_seqlens,
        )

        # Split the fused projection along the last dim into query / key / value widths.
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )
        # Expose a per-head axis; -1 lets reshape infer the head count.
        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # When there are more value heads than key heads, repeat query/key heads to match.
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)

        # `cu_seqlens` is only passed as a keyword when it was actually computed.
        core_attn_out, _ = chunk_gated_delta_rule(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None,
            output_final_state=False,
            use_qk_l2norm_in_kernel=True,
            **({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
        )

        # Flatten to rows of head_v_dim for the norm call (which also receives z), then restore [B, T, -1].
        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
        output = self.out_proj(core_attn_out)
        return output

    # The assignments below replace `forward` on the classes, not on this `model` instance.
    if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet

        Qwen3_5DecoderLayer.forward = _patched_decoder_forward
        Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
    elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
            Qwen3_5MoeDecoderLayer,
            Qwen3_5MoeGatedDeltaNet,
        )

        Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
        Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward

    logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input only patch when do training.")
def patch_youtu_vl_model(model: "PreTrainedModel") -> None: def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
original_forward = model.forward original_forward = model.forward
@@ -232,6 +417,9 @@ def patch_model(
autocast_projector_dtype(model, model_args) autocast_projector_dtype(model, model_args)
add_z3_leaf_module(model) add_z3_leaf_module(model)
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
patch_qwen3_5_forward(model)
if not model_args.use_unsloth: if not model_args.use_unsloth:
print_attn_implementation(model.config) print_attn_implementation(model.config)

View File

@@ -17,7 +17,6 @@ import gc
import os import os
import torch import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp import torch.distributed.checkpoint as dcp
import torch.nn as nn import torch.nn as nn
from peft.tuners.lora import LoraLayer from peft.tuners.lora import LoraLayer
@@ -43,6 +42,37 @@ from ....utils.types import HFModel, Processor
logger = get_logger(__name__) logger = get_logger(__name__)
def _fallback_dot_natural_key(name: str):
parts = []
for part in name.split("."):
if part.isdigit():
parts.append((0, int(part)))
else:
parts.append((1, part))
return parts
def _get_checkpoint_sort_key():
    """Return the key function used to order checkpoint tensor names.

    Uses ``transformers.core_model_loading.dot_natural_key`` when it can be
    imported, and the local ``_fallback_dot_natural_key`` otherwise.
    """
    try:
        from transformers.core_model_loading import dot_natural_key as sort_key
    except ImportError:
        # This transformers build does not provide the helper; use the local equivalent.
        sort_key = _fallback_dot_natural_key
    return sort_key
def _make_safetensor_loader(checkpoint_file: str, tensor_key: str):
# Delay tensor materialization until converter.convert() to reduce peak CPU memory.
# This works because HF WeightConverter accepts callables and materializes them later.
def _load_tensor():
from safetensors import safe_open
with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
return f.get_tensor(tensor_key)
return _load_tensor
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None: def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None) no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules: if no_split_modules:
@@ -129,7 +159,7 @@ class FSDP2Engine:
self.offload_params = dist_config.get("offload_params", False) self.offload_params = dist_config.get("offload_params", False)
self.pin_memory = dist_config.get("pin_memory", True) self.pin_memory = dist_config.get("pin_memory", True)
self.dcp_path = dist_config.get("dcp_path", None) self.dcp_path = dist_config.get("dcp_path", None)
self.device_mesh = self.dist_interface.data_device_mesh self.device_mesh = self.dist_interface.model_device_mesh
if self.device_mesh is None: if self.device_mesh is None:
logger.warning( logger.warning(
@@ -303,9 +333,6 @@ class FSDP2Engine:
else: else:
full_sd = {} full_sd = {}
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
model = self.prepare_model(model) model = self.prepare_model(model)
device = get_current_accelerator() device = get_current_accelerator()
@@ -316,11 +343,6 @@ class FSDP2Engine:
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True) options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
set_model_state_dict(model, full_sd, options=options) set_model_state_dict(model, full_sd, options=options)
# Broadcast and restore non-persistent buffers
buffers_to_sync = [saved_buffers]
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
if self.rank == 0: if self.rank == 0:
logger.info("init_on_rank0 sync complete.") logger.info("init_on_rank0 sync complete.")
@@ -387,11 +409,37 @@ class FSDP2Engine:
logger.error(f"Failed to load from DCP: {e}") logger.error(f"Failed to load from DCP: {e}")
raise e raise e
def _try_build_hf_weight_conversion_context(self, model: HFModel) -> dict | None:
    """Collect what is needed to apply HF weight renamings/conversions while loading.

    Args:
        model: Model whose conversion mapping and state dict are queried.

    Returns:
        ``None`` when the transformers conversion APIs cannot be imported or the
        model has no conversion mapping. Otherwise a dict with the keys
        ``prefix``, ``meta_state_dict``, ``rename_source_key``, ``renamings``,
        ``converters``, ``converter_templates`` and ``pending_converters``.
    """
    try:
        from transformers.conversion_mapping import get_model_conversion_mapping
        from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key
    except ImportError:
        return None

    mapping_entries = get_model_conversion_mapping(model)
    if not mapping_entries:
        return None

    # Two independent checks (not elif): an entry matching both types lands in both lists.
    renamings = []
    converters = []
    for entry in mapping_entries:
        if isinstance(entry, WeightRenaming):
            renamings.append(entry)
        if isinstance(entry, WeightConverter):
            converters.append(entry)

    # Index converters by each source pattern; a later converter wins on duplicates.
    templates_by_pattern = {}
    for converter in converters:
        for pattern in converter.source_patterns:
            templates_by_pattern[pattern] = converter

    context = {
        "prefix": getattr(model, "base_model_prefix", ""),
        "meta_state_dict": model.state_dict(),
        "rename_source_key": rename_source_key,
        "renamings": renamings,
        "converters": converters,
        "converter_templates": templates_by_pattern,
        # Starts empty; the visible loader code fills it per renamed key while reading the checkpoint.
        "pending_converters": {},
    }
    return context
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str): def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob import glob
import json import json
hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path) hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
sort_key = _get_checkpoint_sort_key()
if self.rank == 0: if self.rank == 0:
logger.info(f"Loading weights from {hf_model_path} ...") logger.info(f"Loading weights from {hf_model_path} ...")
@@ -428,6 +476,7 @@ class FSDP2Engine:
raise ValueError(f"No checkpoint files found in {hf_model_path}") raise ValueError(f"No checkpoint files found in {hf_model_path}")
param_map = dict(model.named_parameters()) param_map = dict(model.named_parameters())
conversion_ctx = self._try_build_hf_weight_conversion_context(model)
total_files = len(checkpoint_files) total_files = len(checkpoint_files)
for i, ckpt_file in enumerate(checkpoint_files): for i, ckpt_file in enumerate(checkpoint_files):
@@ -438,18 +487,71 @@ class FSDP2Engine:
from safetensors import safe_open from safetensors import safe_open
with safe_open(ckpt_file, framework="pt", device="cpu") as f: with safe_open(ckpt_file, framework="pt", device="cpu") as f:
for key in f.keys(): for key in sorted(f.keys(), key=sort_key):
if key in param_map: renamed_key = key
source_pattern = None
if conversion_ctx is not None:
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
key,
conversion_ctx["renamings"],
conversion_ctx["converters"],
prefix=conversion_ctx["prefix"],
meta_state_dict=conversion_ctx["meta_state_dict"],
)
if source_pattern is not None:
template = conversion_ctx["converter_templates"][source_pattern]
converter = conversion_ctx["pending_converters"].setdefault(
renamed_key, copy.deepcopy(template)
)
converter.add_tensor(
renamed_key,
key,
source_pattern,
_make_safetensor_loader(ckpt_file, key),
)
elif renamed_key in param_map:
tensor = f.get_tensor(key) tensor = f.get_tensor(key)
self._copy_weights(param_map[key], tensor) self._copy_weights(param_map[renamed_key], tensor)
else: else:
state_dict = torch.load(ckpt_file, map_location="cpu") state_dict = torch.load(ckpt_file, map_location="cpu")
for key, tensor in state_dict.items(): for key, tensor in sorted(state_dict.items(), key=lambda item: sort_key(item[0])):
if key in param_map: renamed_key = key
self._copy_weights(param_map[key], tensor) source_pattern = None
if conversion_ctx is not None:
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
key,
conversion_ctx["renamings"],
conversion_ctx["converters"],
prefix=conversion_ctx["prefix"],
meta_state_dict=conversion_ctx["meta_state_dict"],
)
if source_pattern is not None:
template = conversion_ctx["converter_templates"][source_pattern]
converter = conversion_ctx["pending_converters"].setdefault(
renamed_key, copy.deepcopy(template)
)
converter.add_tensor(renamed_key, key, source_pattern, tensor)
elif renamed_key in param_map:
self._copy_weights(param_map[renamed_key], tensor)
del state_dict del state_dict
gc.collect() gc.collect()
if conversion_ctx is not None:
pending_count = len(conversion_ctx["pending_converters"])
log_fn = getattr(logger, "info_rank0", logger.info)
log_fn(f"Applying {pending_count} deferred HF weight conversions.")
for layer_name, converter in sorted(conversion_ctx["pending_converters"].items()):
realized_tensors = converter.convert(layer_name, model=model, config=model.config)
for target_name, tensor in realized_tensors.items():
if isinstance(tensor, list):
tensor = tensor[0]
if target_name in param_map:
self._copy_weights(param_map[target_name], tensor)
del realized_tensors
gc.collect()
def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str: def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
"""Resolve a HF model identifier or local path to a local directory containing checkpoint files. """Resolve a HF model identifier or local path to a local directory containing checkpoint files.

View File

@@ -181,6 +181,39 @@ def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2), (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
) )
@pytest.mark.runs_on(["cpu", "mps"])
@pytest.mark.parametrize("enable_thinking", [True, False, None])
@pytest.mark.parametrize("discarding_history_cot", [True, False])
def test_reasoning_encode_multiturn_discarding_history_cot(enable_thinking: bool, discarding_history_cot: bool):
    """Check ``encode_multiturn`` on the qwen3 template with and without ``discarding_history_cot``.

    Encodes ``MESSAGES_WITH_THOUGHT`` and compares the first two prompt/response
    pairs against expected strings for every combination of ``enable_thinking``
    (which is also parametrized with ``None``) and ``discarding_history_cot``.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT, discarding_history_cot=discarding_history_cot)

    prompt_str_1 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
    prompt_str_2 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[2]['content']}<|im_end|>\n<|im_start|>assistant\n"

    if enable_thinking is False:
        # Thinking disabled: both answers are expected without their thought (plain MESSAGES content).
        answer_str_1 = f"{MESSAGES[1]['content']}<|im_end|>\n"
        answer_str_2 = f"{MESSAGES[3]['content']}<|im_end|>\n"
        if discarding_history_cot:
            # Only the last turn's prompt is expected to carry the empty think block.
            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
        else:
            # Every turn's prompt is expected to carry the empty think block.
            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
    else:
        if discarding_history_cot:
            # The history turn is expected without its thought; the last turn keeps it.
            answer_str_1 = f"{MESSAGES[1]['content']}<|im_end|>\n"
        else:
            answer_str_1 = f"{MESSAGES_WITH_THOUGHT[1]['content']}<|im_end|>\n"

        answer_str_2 = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"

    _check_tokenization(
        tokenizer,
        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
    )
@pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.runs_on(["cpu", "mps"])
def test_jinja_template(): def test_jinja_template():

View File

@@ -0,0 +1,137 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import torch
import torch.nn as nn
from safetensors.torch import save_file
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
NUM_EXPERTS = 11
HIDDEN_SIZE = 3
INTERMEDIATE_SIZE = 2
class Holder(nn.Module):
    """Empty ``nn.Module`` used as an attribute container when building fake model trees."""
class FakeFusedExpertsModel(nn.Module):
    """Tiny stand-in model whose experts are stored as fused 3-D parameters.

    Exposes ``model.layers.0.mlp.experts.gate_up_proj`` with shape
    ``(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)`` and
    ``model.layers.0.mlp.experts.down_proj`` with shape
    ``(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)``, both zero-initialized.
    """

    base_model_prefix = "model"

    def __init__(self):
        super().__init__()
        self.config = types.SimpleNamespace(model_type="qwen3_moe")

        # Build the tree bottom-up; parameter registration order is gate_up_proj, then down_proj.
        experts = Holder()
        experts.gate_up_proj = nn.Parameter(torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE))
        experts.down_proj = nn.Parameter(torch.zeros(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE))

        mlp = Holder()
        mlp.experts = experts

        layer = Holder()
        layer.mlp = mlp

        self.model = Holder()
        self.model.layers = nn.ModuleList([layer])
class FakeLegacyExpert(nn.Module):
    """Single expert with separate bias-free ``gate_proj``, ``up_proj`` and ``down_proj`` linears."""

    def __init__(self):
        super().__init__()
        hidden, inner = HIDDEN_SIZE, INTERMEDIATE_SIZE
        # Creation order is kept as gate, up, down so random initialization draws stay in the same sequence.
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
class FakeLegacyExpertsModel(nn.Module):
    """Tiny stand-in model whose experts are a ``ModuleList`` of per-expert linears.

    Exposes ``model.layers.0.mlp.experts.{i}`` for ``i`` in ``range(NUM_EXPERTS)``,
    each a ``FakeLegacyExpert``.
    """

    base_model_prefix = "model"

    def __init__(self):
        super().__init__()
        self.config = types.SimpleNamespace(model_type="qwen3_moe")

        expert_modules = [FakeLegacyExpert() for _ in range(NUM_EXPERTS)]

        mlp = Holder()
        mlp.experts = nn.ModuleList(expert_modules)

        layer = Holder()
        layer.mlp = mlp

        self.model = Holder()
        self.model.layers = nn.ModuleList([layer])
def build_engine():
    """Create an ``FSDP2Engine`` configured with only ``{"name": "fsdp2"}``.

    ``DistributedInterface()`` is instantiated first and its result discarded —
    presumably to set up distributed state the engine reads; confirm.
    """
    DistributedInterface()
    dist_config = {"name": "fsdp2"}
    return FSDP2Engine(dist_config)
def build_checkpoint():
    """Build a legacy per-expert checkpoint dict plus the tensors it contains.

    Returns:
        A tuple ``(ckpt, gates, ups, downs)``. ``ckpt`` maps
        ``model.layers.0.mlp.experts.{i}.{gate,up,down}_proj.weight`` to tensors,
        and the three lists hold the same tensors in expert order.
    """
    gate_up_shape = (INTERMEDIATE_SIZE, HIDDEN_SIZE)
    down_shape = (HIDDEN_SIZE, INTERMEDIATE_SIZE)

    # Use distinct values per expert so ordering bugs are easy to catch.
    gates = [torch.full(gate_up_shape, float(i)) for i in range(NUM_EXPERTS)]
    ups = [torch.full(gate_up_shape, float(i) + 100.0) for i in range(NUM_EXPERTS)]
    downs = [torch.full(down_shape, float(i) + 200.0) for i in range(NUM_EXPERTS)]

    ckpt = {}
    for i, (gate, up, down) in enumerate(zip(gates, ups, downs)):
        ckpt[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = gate
        ckpt[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = up
        ckpt[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = down

    return ckpt, gates, ups, downs
def test_fsdp2_gate_up_proj_loading(tmp_path):
    """Load a legacy per-expert checkpoint through ``FSDP2Engine`` and check the resulting weights.

    When a weight-conversion context is available, the per-expert tensors are
    expected to end up fused in ``gate_up_proj`` / ``down_proj``. Otherwise they
    are expected to be copied one-to-one into a legacy per-expert model.
    """
    engine = build_engine()
    ckpt, gates, ups, downs = build_checkpoint()
    checkpoint_dir = str(tmp_path)
    save_file(ckpt, str(tmp_path / "model.safetensors"))

    fused_model = FakeFusedExpertsModel()
    conversion_ctx = engine._try_build_hf_weight_conversion_context(fused_model)

    if conversion_ctx is None:
        # In pre-v5 environments, the loader should fall back to direct copy.
        legacy_model = FakeLegacyExpertsModel()
        engine._load_weights_from_hf_checkpoint(legacy_model, checkpoint_dir)
        legacy_experts = legacy_model.model.layers[0].mlp.experts
        for i in range(NUM_EXPERTS):
            expert = legacy_experts[i]
            assert torch.allclose(expert.gate_proj.weight, gates[i])
            assert torch.allclose(expert.up_proj.weight, ups[i])
            assert torch.allclose(expert.down_proj.weight, downs[i])
        return

    # In transformers v5-style environments, legacy expert weights should be fused.
    engine._load_weights_from_hf_checkpoint(fused_model, checkpoint_dir)

    stacked_gates = torch.stack(gates, dim=0)
    stacked_ups = torch.stack(ups, dim=0)
    expected_gate_up = torch.cat([stacked_gates, stacked_ups], dim=1)
    expected_down = torch.stack(downs, dim=0)

    fused_experts = fused_model.model.layers[0].mlp.experts
    assert torch.allclose(fused_experts.gate_up_proj, expected_gate_up)
    assert torch.allclose(fused_experts.down_proj, expected_down)

    # Check a double-digit expert index to ensure natural ordering is preserved.
    assert torch.allclose(fused_experts.gate_up_proj[2], expected_gate_up[2])
    assert torch.allclose(fused_experts.gate_up_proj[10], expected_gate_up[10])
    assert torch.allclose(fused_experts.down_proj[2], expected_down[2])
    assert torch.allclose(fused_experts.down_proj[10], expected_down[10])