mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
Compare commits
3 Commits
e0bc3c1971
...
9a0cfdccfa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9a0cfdccfa | ||
|
|
c8890c32db | ||
|
|
79c8332e4c |
@@ -471,8 +471,8 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.neat_packing and self.attn_implementation == "flash_attention_2":
|
||||
if self.model is not None and getattr(self.model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe", "gpt_oss"]:
|
||||
raise ValueError("Neat packing is not supported for qwen3_5, qwen3_5_moe, gpt_oss models for now.")
|
||||
if self.model is not None and getattr(self.model.config, "model_type", None) in ["gemma4", "gpt_oss"]:
|
||||
raise ValueError("Neat packing is not supported for gemma4, gpt_oss models for now.")
|
||||
|
||||
@staticmethod
|
||||
def _unpad_packed_features(features: dict[str, Any]) -> None:
|
||||
|
||||
@@ -61,7 +61,8 @@ class SupervisedDatasetProcessor(DatasetProcessor):
|
||||
input_ids, labels = self.template.mm_plugin.process_token_ids(
|
||||
[], [], images, videos, audios, self.tokenizer, self.processor
|
||||
)
|
||||
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
|
||||
discarding_history_cot = self.data_args.mask_history and not self.template.preserve_thinking
|
||||
encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools, discarding_history_cot)
|
||||
total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
|
||||
if self.data_args.mask_history:
|
||||
encoded_pairs = encoded_pairs[::-1] # high priority for last turns
|
||||
|
||||
@@ -79,6 +79,7 @@ class Template:
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
discarding_history_cot: bool = False, # only effect reasoning template
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
r"""Return multiple pairs of token ids representing prompts and responses respectively."""
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
@@ -441,14 +442,24 @@ class ReasoningTemplate(Template):
|
||||
messages: list[dict[str, str]],
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[str] = None,
|
||||
discarding_history_cot: bool = False,
|
||||
) -> list[tuple[list[int], list[int]]]:
|
||||
messages = deepcopy(messages)
|
||||
if self.enable_thinking is False: # remove all cot
|
||||
for i in range(1, len(messages), 2):
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
|
||||
if discarding_history_cot:
|
||||
for i in range(1, len(messages) - 2, 2): # preserve the last cot
|
||||
messages[i]["content"] = self.remove_thought(messages[i]["content"])
|
||||
|
||||
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
||||
for i in range(0, len(messages), 2):
|
||||
if discarding_history_cot:
|
||||
turn_indices = [len(messages) - 2]
|
||||
else:
|
||||
turn_indices = range(0, len(messages), 2)
|
||||
|
||||
for i in turn_indices:
|
||||
if (
|
||||
self.thought_words[0].strip() not in messages[i + 1]["content"]
|
||||
and self.thought_words[1].strip() not in messages[i + 1]["content"]
|
||||
@@ -2135,23 +2146,6 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
# copied from qwen3_5_nothink template
|
||||
register_template(
|
||||
name="qwen3_6_nothink",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="sailor",
|
||||
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
||||
|
||||
@@ -60,6 +60,191 @@ def patch_qwen3_omni_moe_thinker_text_sparse_moe_block():
|
||||
modeling_qwen3_omni_moe.Qwen3OmniMoeThinkerTextSparseMoeBlock = Qwen3OmniMoeThinkerTextSparseMoeBlock
|
||||
|
||||
|
||||
def _check_fla_dependencies() -> None:
|
||||
"""Check that the FLA dependencies required for varlen GDN forwarding are available.
|
||||
|
||||
Requires ``flash-linear-attention >= 0.4.1`` (which exposes the varlen
|
||||
``causal_conv1d`` under ``fla.modules.convolution`` and the
|
||||
``chunk_gated_delta_rule`` / ``fused_recurrent_gated_delta_rule`` kernels
|
||||
under ``fla.ops.gated_delta_rule``). Raises ``ImportError`` with an
|
||||
actionable message otherwise.
|
||||
"""
|
||||
try:
|
||||
from fla.modules.convolution import causal_conv1d # noqa: F401
|
||||
from fla.ops.gated_delta_rule import ( # noqa: F401
|
||||
chunk_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Qwen3.5 packing-seq forwarding requires `flash-linear-attention>=0.4.1` "
|
||||
"(provides `fla.modules.convolution.causal_conv1d` and "
|
||||
"`fla.ops.gated_delta_rule.{chunk,fused_recurrent}_gated_delta_rule`). "
|
||||
"Please install/upgrade it."
|
||||
) from exc
|
||||
|
||||
|
||||
def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
|
||||
"""Patch the forward method of Qwen3_5ForConditionalGeneration to support cu_seqlens input only patch when do training.
|
||||
|
||||
Refer to: https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/models/qwen3_5/modeling.py.
|
||||
"""
|
||||
if is_transformers_version_greater_than("5.2.0"):
|
||||
from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states
|
||||
|
||||
from torch.nn import functional as F
|
||||
from transformers.modeling_flash_attention_utils import prepare_fa_kwargs_from_position_ids
|
||||
|
||||
_check_fla_dependencies()
|
||||
from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
|
||||
from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||
|
||||
def _patched_decoder_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_values=None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""Decoder layer forward that passes position_ids through to linear attention."""
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
if self.layer_type == "linear_attention":
|
||||
hidden_states = self.linear_attn(
|
||||
hidden_states=hidden_states,
|
||||
cache_params=past_key_values,
|
||||
cache_position=cache_position,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids, # passing position_ids to linear attention
|
||||
)
|
||||
elif self.layer_type == "full_attention":
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if isinstance(hidden_states, tuple): # MoE returns (hidden_states, router_logits)
|
||||
hidden_states, _ = hidden_states
|
||||
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
# gdn forward (training only, cache_params is always None)
|
||||
def _patch_gdn_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache_params=None,
|
||||
cache_position: torch.LongTensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
):
|
||||
# @kuangdd fix: here attention_mask is None
|
||||
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
|
||||
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# Qwen3.5 VL passes 3-D MRoPE position_ids ([axes, B, T]); collapse to [B, T].
|
||||
if position_ids is not None and position_ids.ndim == 3:
|
||||
position_ids = position_ids[0]
|
||||
|
||||
# `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
|
||||
cu_seqlens = (
|
||||
prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
|
||||
if position_ids is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
|
||||
# standard causal-conv1d path that the upstream forward uses.
|
||||
mixed_qkv = self.in_proj_qkv(hidden_states)
|
||||
|
||||
z = self.in_proj_z(hidden_states)
|
||||
z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
|
||||
|
||||
b = self.in_proj_b(hidden_states)
|
||||
a = self.in_proj_a(hidden_states)
|
||||
|
||||
# FLA's causal_conv1d returns (out, final_state); we don't use the state here.
|
||||
mixed_qkv, _ = fla_causal_conv1d(
|
||||
x=mixed_qkv,
|
||||
weight=self.conv1d.weight.squeeze(1),
|
||||
bias=self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
query, key, value = torch.split(
|
||||
mixed_qkv,
|
||||
[
|
||||
self.key_dim,
|
||||
self.key_dim,
|
||||
self.value_dim,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
|
||||
key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
|
||||
value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
|
||||
|
||||
beta = b.sigmoid()
|
||||
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
|
||||
if self.num_v_heads // self.num_k_heads > 1:
|
||||
query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||
key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
|
||||
|
||||
core_attn_out, _ = chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g=g,
|
||||
beta=beta,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
**({"cu_seqlens": cu_seqlens} if cu_seqlens is not None else {}),
|
||||
)
|
||||
|
||||
core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
|
||||
z = z.reshape(-1, self.head_v_dim)
|
||||
core_attn_out = self.norm(core_attn_out, z)
|
||||
core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
|
||||
|
||||
output = self.out_proj(core_attn_out)
|
||||
|
||||
return output
|
||||
|
||||
if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
|
||||
from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
|
||||
Qwen3_5DecoderLayer.forward = _patched_decoder_forward
|
||||
Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
|
||||
elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
|
||||
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
|
||||
Qwen3_5MoeDecoderLayer,
|
||||
Qwen3_5MoeGatedDeltaNet,
|
||||
)
|
||||
Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
|
||||
Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward
|
||||
|
||||
logger.info_rank0("Patched Qwen3.5 decoder forward to support cu_seqlens input only patch when do training.")
|
||||
|
||||
|
||||
def patch_youtu_vl_model(model: "PreTrainedModel") -> None:
|
||||
original_forward = model.forward
|
||||
|
||||
@@ -232,6 +417,9 @@ def patch_model(
|
||||
autocast_projector_dtype(model, model_args)
|
||||
add_z3_leaf_module(model)
|
||||
|
||||
if getattr(model.config, "model_type", None) in ["qwen3_5", "qwen3_5_moe"] and model_args.flash_attn == "fa2":
|
||||
patch_qwen3_5_forward(model)
|
||||
|
||||
if not model_args.use_unsloth:
|
||||
print_attn_implementation(model.config)
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ import gc
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.distributed.checkpoint as dcp
|
||||
import torch.nn as nn
|
||||
from peft.tuners.lora import LoraLayer
|
||||
@@ -43,6 +42,37 @@ from ....utils.types import HFModel, Processor
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _fallback_dot_natural_key(name: str):
|
||||
parts = []
|
||||
for part in name.split("."):
|
||||
if part.isdigit():
|
||||
parts.append((0, int(part)))
|
||||
else:
|
||||
parts.append((1, part))
|
||||
return parts
|
||||
|
||||
|
||||
def _get_checkpoint_sort_key():
|
||||
try:
|
||||
from transformers.core_model_loading import dot_natural_key
|
||||
|
||||
return dot_natural_key
|
||||
except ImportError:
|
||||
return _fallback_dot_natural_key
|
||||
|
||||
|
||||
def _make_safetensor_loader(checkpoint_file: str, tensor_key: str):
|
||||
# Delay tensor materialization until converter.convert() to reduce peak CPU memory.
|
||||
# This works because HF WeightConverter accepts callables and materializes them later.
|
||||
def _load_tensor():
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
|
||||
return f.get_tensor(tensor_key)
|
||||
|
||||
return _load_tensor
|
||||
|
||||
|
||||
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
|
||||
no_split_modules = getattr(model, "_no_split_modules", None)
|
||||
if no_split_modules:
|
||||
@@ -129,7 +159,7 @@ class FSDP2Engine:
|
||||
self.offload_params = dist_config.get("offload_params", False)
|
||||
self.pin_memory = dist_config.get("pin_memory", True)
|
||||
self.dcp_path = dist_config.get("dcp_path", None)
|
||||
self.device_mesh = self.dist_interface.data_device_mesh
|
||||
self.device_mesh = self.dist_interface.model_device_mesh
|
||||
|
||||
if self.device_mesh is None:
|
||||
logger.warning(
|
||||
@@ -303,9 +333,6 @@ class FSDP2Engine:
|
||||
else:
|
||||
full_sd = {}
|
||||
|
||||
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
|
||||
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
|
||||
|
||||
model = self.prepare_model(model)
|
||||
|
||||
device = get_current_accelerator()
|
||||
@@ -316,11 +343,6 @@ class FSDP2Engine:
|
||||
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
|
||||
set_model_state_dict(model, full_sd, options=options)
|
||||
|
||||
# Broadcast and restore non-persistent buffers
|
||||
buffers_to_sync = [saved_buffers]
|
||||
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
|
||||
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info("init_on_rank0 sync complete.")
|
||||
|
||||
@@ -387,11 +409,37 @@ class FSDP2Engine:
|
||||
logger.error(f"Failed to load from DCP: {e}")
|
||||
raise e
|
||||
|
||||
def _try_build_hf_weight_conversion_context(self, model: HFModel) -> dict | None:
|
||||
try:
|
||||
from transformers.conversion_mapping import get_model_conversion_mapping
|
||||
from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
weight_mapping = get_model_conversion_mapping(model)
|
||||
if not weight_mapping:
|
||||
return None
|
||||
|
||||
renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
|
||||
converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
|
||||
return {
|
||||
"prefix": getattr(model, "base_model_prefix", ""),
|
||||
"meta_state_dict": model.state_dict(),
|
||||
"rename_source_key": rename_source_key,
|
||||
"renamings": renamings,
|
||||
"converters": converters,
|
||||
"converter_templates": {
|
||||
pattern: converter for converter in converters for pattern in converter.source_patterns
|
||||
},
|
||||
"pending_converters": {},
|
||||
}
|
||||
|
||||
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
|
||||
import glob
|
||||
import json
|
||||
|
||||
hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
|
||||
sort_key = _get_checkpoint_sort_key()
|
||||
|
||||
if self.rank == 0:
|
||||
logger.info(f"Loading weights from {hf_model_path} ...")
|
||||
@@ -428,6 +476,7 @@ class FSDP2Engine:
|
||||
raise ValueError(f"No checkpoint files found in {hf_model_path}")
|
||||
|
||||
param_map = dict(model.named_parameters())
|
||||
conversion_ctx = self._try_build_hf_weight_conversion_context(model)
|
||||
total_files = len(checkpoint_files)
|
||||
|
||||
for i, ckpt_file in enumerate(checkpoint_files):
|
||||
@@ -438,18 +487,71 @@ class FSDP2Engine:
|
||||
from safetensors import safe_open
|
||||
|
||||
with safe_open(ckpt_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key in param_map:
|
||||
for key in sorted(f.keys(), key=sort_key):
|
||||
renamed_key = key
|
||||
source_pattern = None
|
||||
if conversion_ctx is not None:
|
||||
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
|
||||
key,
|
||||
conversion_ctx["renamings"],
|
||||
conversion_ctx["converters"],
|
||||
prefix=conversion_ctx["prefix"],
|
||||
meta_state_dict=conversion_ctx["meta_state_dict"],
|
||||
)
|
||||
|
||||
if source_pattern is not None:
|
||||
template = conversion_ctx["converter_templates"][source_pattern]
|
||||
converter = conversion_ctx["pending_converters"].setdefault(
|
||||
renamed_key, copy.deepcopy(template)
|
||||
)
|
||||
converter.add_tensor(
|
||||
renamed_key,
|
||||
key,
|
||||
source_pattern,
|
||||
_make_safetensor_loader(ckpt_file, key),
|
||||
)
|
||||
elif renamed_key in param_map:
|
||||
tensor = f.get_tensor(key)
|
||||
self._copy_weights(param_map[key], tensor)
|
||||
self._copy_weights(param_map[renamed_key], tensor)
|
||||
else:
|
||||
state_dict = torch.load(ckpt_file, map_location="cpu")
|
||||
for key, tensor in state_dict.items():
|
||||
if key in param_map:
|
||||
self._copy_weights(param_map[key], tensor)
|
||||
for key, tensor in sorted(state_dict.items(), key=lambda item: sort_key(item[0])):
|
||||
renamed_key = key
|
||||
source_pattern = None
|
||||
if conversion_ctx is not None:
|
||||
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
|
||||
key,
|
||||
conversion_ctx["renamings"],
|
||||
conversion_ctx["converters"],
|
||||
prefix=conversion_ctx["prefix"],
|
||||
meta_state_dict=conversion_ctx["meta_state_dict"],
|
||||
)
|
||||
|
||||
if source_pattern is not None:
|
||||
template = conversion_ctx["converter_templates"][source_pattern]
|
||||
converter = conversion_ctx["pending_converters"].setdefault(
|
||||
renamed_key, copy.deepcopy(template)
|
||||
)
|
||||
converter.add_tensor(renamed_key, key, source_pattern, tensor)
|
||||
elif renamed_key in param_map:
|
||||
self._copy_weights(param_map[renamed_key], tensor)
|
||||
del state_dict
|
||||
gc.collect()
|
||||
|
||||
if conversion_ctx is not None:
|
||||
pending_count = len(conversion_ctx["pending_converters"])
|
||||
log_fn = getattr(logger, "info_rank0", logger.info)
|
||||
log_fn(f"Applying {pending_count} deferred HF weight conversions.")
|
||||
for layer_name, converter in sorted(conversion_ctx["pending_converters"].items()):
|
||||
realized_tensors = converter.convert(layer_name, model=model, config=model.config)
|
||||
for target_name, tensor in realized_tensors.items():
|
||||
if isinstance(tensor, list):
|
||||
tensor = tensor[0]
|
||||
if target_name in param_map:
|
||||
self._copy_weights(param_map[target_name], tensor)
|
||||
del realized_tensors
|
||||
gc.collect()
|
||||
|
||||
def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
|
||||
"""Resolve a HF model identifier or local path to a local directory containing checkpoint files.
|
||||
|
||||
|
||||
@@ -181,6 +181,39 @@ def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
|
||||
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
|
||||
)
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
@pytest.mark.parametrize("enable_thinking", [True, False, None])
|
||||
@pytest.mark.parametrize("discarding_history_cot", [True, False])
|
||||
def test_reasoning_encode_multiturn_discarding_history_cot(enable_thinking: bool, discarding_history_cot: bool):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
||||
data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
|
||||
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||
encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT, discarding_history_cot=discarding_history_cot)
|
||||
|
||||
prompt_str_1 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
prompt_str_2 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
if enable_thinking is False:
|
||||
answer_str_1 = f"{MESSAGES[1]['content']}<|im_end|>\n"
|
||||
answer_str_2 = f"{MESSAGES[3]['content']}<|im_end|>\n"
|
||||
if discarding_history_cot:
|
||||
prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
|
||||
else:
|
||||
prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
|
||||
prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
|
||||
else:
|
||||
if discarding_history_cot:
|
||||
answer_str_1 = f"{MESSAGES[1]['content']}<|im_end|>\n"
|
||||
else:
|
||||
answer_str_1 = f"{MESSAGES_WITH_THOUGHT[1]['content']}<|im_end|>\n"
|
||||
answer_str_2 = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
|
||||
|
||||
_check_tokenization(
|
||||
tokenizer,
|
||||
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
|
||||
(prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cpu", "mps"])
|
||||
def test_jinja_template():
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import types
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
|
||||
|
||||
|
||||
NUM_EXPERTS = 11
|
||||
HIDDEN_SIZE = 3
|
||||
INTERMEDIATE_SIZE = 2
|
||||
|
||||
|
||||
class Holder(nn.Module):
|
||||
pass
|
||||
|
||||
|
||||
class FakeFusedExpertsModel(nn.Module):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = types.SimpleNamespace(model_type="qwen3_moe")
|
||||
|
||||
self.model = Holder()
|
||||
self.model.layers = nn.ModuleList([Holder()])
|
||||
self.model.layers[0].mlp = Holder()
|
||||
self.model.layers[0].mlp.experts = Holder()
|
||||
self.model.layers[0].mlp.experts.gate_up_proj = nn.Parameter(
|
||||
torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
|
||||
)
|
||||
self.model.layers[0].mlp.experts.down_proj = nn.Parameter(
|
||||
torch.zeros(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
|
||||
)
|
||||
|
||||
|
||||
class FakeLegacyExpert(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
|
||||
self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
|
||||
self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)
|
||||
|
||||
|
||||
class FakeLegacyExpertsModel(nn.Module):
|
||||
base_model_prefix = "model"
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.config = types.SimpleNamespace(model_type="qwen3_moe")
|
||||
|
||||
self.model = Holder()
|
||||
self.model.layers = nn.ModuleList([Holder()])
|
||||
self.model.layers[0].mlp = Holder()
|
||||
self.model.layers[0].mlp.experts = nn.ModuleList([FakeLegacyExpert() for _ in range(NUM_EXPERTS)])
|
||||
|
||||
|
||||
def build_engine():
|
||||
DistributedInterface()
|
||||
return FSDP2Engine({"name": "fsdp2"})
|
||||
|
||||
|
||||
def build_checkpoint():
|
||||
ckpt = {}
|
||||
gates, ups, downs = [], [], []
|
||||
|
||||
for i in range(NUM_EXPERTS):
|
||||
# Use distinct values per expert so ordering bugs are easy to catch.
|
||||
gate = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i))
|
||||
up = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i) + 100.0)
|
||||
down = torch.full((HIDDEN_SIZE, INTERMEDIATE_SIZE), float(i) + 200.0)
|
||||
|
||||
ckpt[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = gate
|
||||
ckpt[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = up
|
||||
ckpt[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = down
|
||||
|
||||
gates.append(gate)
|
||||
ups.append(up)
|
||||
downs.append(down)
|
||||
|
||||
return ckpt, gates, ups, downs
|
||||
|
||||
|
||||
def test_fsdp2_gate_up_proj_loading(tmp_path):
|
||||
engine = build_engine()
|
||||
ckpt, gates, ups, downs = build_checkpoint()
|
||||
save_file(ckpt, str(tmp_path / "model.safetensors"))
|
||||
|
||||
fused_model = FakeFusedExpertsModel()
|
||||
conversion_ctx = engine._try_build_hf_weight_conversion_context(fused_model)
|
||||
|
||||
if conversion_ctx is not None:
|
||||
# In transformers v5-style environments, legacy expert weights should be fused.
|
||||
engine._load_weights_from_hf_checkpoint(fused_model, str(tmp_path))
|
||||
|
||||
expected_gate_up = torch.cat(
|
||||
[torch.stack(gates, dim=0), torch.stack(ups, dim=0)],
|
||||
dim=1,
|
||||
)
|
||||
expected_down = torch.stack(downs, dim=0)
|
||||
|
||||
experts = fused_model.model.layers[0].mlp.experts
|
||||
assert torch.allclose(experts.gate_up_proj, expected_gate_up)
|
||||
assert torch.allclose(experts.down_proj, expected_down)
|
||||
|
||||
# Check a double-digit expert index to ensure natural ordering is preserved.
|
||||
assert torch.allclose(experts.gate_up_proj[2], expected_gate_up[2])
|
||||
assert torch.allclose(experts.gate_up_proj[10], expected_gate_up[10])
|
||||
assert torch.allclose(experts.down_proj[2], expected_down[2])
|
||||
assert torch.allclose(experts.down_proj[10], expected_down[10])
|
||||
|
||||
else:
|
||||
# In pre-v5 environments, the loader should fall back to direct copy.
|
||||
legacy_model = FakeLegacyExpertsModel()
|
||||
engine._load_weights_from_hf_checkpoint(legacy_model, str(tmp_path))
|
||||
|
||||
experts = legacy_model.model.layers[0].mlp.experts
|
||||
for i in range(NUM_EXPERTS):
|
||||
assert torch.allclose(experts[i].gate_proj.weight, gates[i])
|
||||
assert torch.allclose(experts[i].up_proj.weight, ups[i])
|
||||
assert torch.allclose(experts[i].down_proj.weight, downs[i])
|
||||
Reference in New Issue
Block a user