[v1] fix init on meta in transformers v5 (#10414)

This commit is contained in:
jiaqiw09
2026-04-27 00:37:09 +08:00
committed by GitHub
parent c8890c32db
commit 9a0cfdccfa
2 changed files with 255 additions and 16 deletions

View File

@@ -17,7 +17,6 @@ import gc
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from peft.tuners.lora import LoraLayer
@@ -43,6 +42,37 @@ from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
def _fallback_dot_natural_key(name: str):
parts = []
for part in name.split("."):
if part.isdigit():
parts.append((0, int(part)))
else:
parts.append((1, part))
return parts
def _get_checkpoint_sort_key():
try:
from transformers.core_model_loading import dot_natural_key
return dot_natural_key
except ImportError:
return _fallback_dot_natural_key
def _make_safetensor_loader(checkpoint_file: str, tensor_key: str):
# Delay tensor materialization until converter.convert() to reduce peak CPU memory.
# This works because HF WeightConverter accepts callables and materializes them later.
def _load_tensor():
from safetensors import safe_open
with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
return f.get_tensor(tensor_key)
return _load_tensor
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
@@ -129,7 +159,7 @@ class FSDP2Engine:
self.offload_params = dist_config.get("offload_params", False)
self.pin_memory = dist_config.get("pin_memory", True)
self.dcp_path = dist_config.get("dcp_path", None)
self.device_mesh = self.dist_interface.data_device_mesh
self.device_mesh = self.dist_interface.model_device_mesh
if self.device_mesh is None:
logger.warning(
@@ -303,9 +333,6 @@ class FSDP2Engine:
else:
full_sd = {}
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
model = self.prepare_model(model)
device = get_current_accelerator()
@@ -316,11 +343,6 @@ class FSDP2Engine:
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
set_model_state_dict(model, full_sd, options=options)
# Broadcast and restore non-persistent buffers
buffers_to_sync = [saved_buffers]
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
if self.rank == 0:
logger.info("init_on_rank0 sync complete.")
@@ -387,11 +409,37 @@ class FSDP2Engine:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _try_build_hf_weight_conversion_context(self, model: HFModel) -> dict | None:
try:
from transformers.conversion_mapping import get_model_conversion_mapping
from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key
except ImportError:
return None
weight_mapping = get_model_conversion_mapping(model)
if not weight_mapping:
return None
renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
return {
"prefix": getattr(model, "base_model_prefix", ""),
"meta_state_dict": model.state_dict(),
"rename_source_key": rename_source_key,
"renamings": renamings,
"converters": converters,
"converter_templates": {
pattern: converter for converter in converters for pattern in converter.source_patterns
},
"pending_converters": {},
}
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob
import json
hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
sort_key = _get_checkpoint_sort_key()
if self.rank == 0:
logger.info(f"Loading weights from {hf_model_path} ...")
@@ -428,6 +476,7 @@ class FSDP2Engine:
raise ValueError(f"No checkpoint files found in {hf_model_path}")
param_map = dict(model.named_parameters())
conversion_ctx = self._try_build_hf_weight_conversion_context(model)
total_files = len(checkpoint_files)
for i, ckpt_file in enumerate(checkpoint_files):
@@ -438,18 +487,71 @@ class FSDP2Engine:
from safetensors import safe_open
with safe_open(ckpt_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key in param_map:
for key in sorted(f.keys(), key=sort_key):
renamed_key = key
source_pattern = None
if conversion_ctx is not None:
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
key,
conversion_ctx["renamings"],
conversion_ctx["converters"],
prefix=conversion_ctx["prefix"],
meta_state_dict=conversion_ctx["meta_state_dict"],
)
if source_pattern is not None:
template = conversion_ctx["converter_templates"][source_pattern]
converter = conversion_ctx["pending_converters"].setdefault(
renamed_key, copy.deepcopy(template)
)
converter.add_tensor(
renamed_key,
key,
source_pattern,
_make_safetensor_loader(ckpt_file, key),
)
elif renamed_key in param_map:
tensor = f.get_tensor(key)
self._copy_weights(param_map[key], tensor)
self._copy_weights(param_map[renamed_key], tensor)
else:
state_dict = torch.load(ckpt_file, map_location="cpu")
for key, tensor in state_dict.items():
if key in param_map:
self._copy_weights(param_map[key], tensor)
for key, tensor in sorted(state_dict.items(), key=lambda item: sort_key(item[0])):
renamed_key = key
source_pattern = None
if conversion_ctx is not None:
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
key,
conversion_ctx["renamings"],
conversion_ctx["converters"],
prefix=conversion_ctx["prefix"],
meta_state_dict=conversion_ctx["meta_state_dict"],
)
if source_pattern is not None:
template = conversion_ctx["converter_templates"][source_pattern]
converter = conversion_ctx["pending_converters"].setdefault(
renamed_key, copy.deepcopy(template)
)
converter.add_tensor(renamed_key, key, source_pattern, tensor)
elif renamed_key in param_map:
self._copy_weights(param_map[renamed_key], tensor)
del state_dict
gc.collect()
if conversion_ctx is not None:
pending_count = len(conversion_ctx["pending_converters"])
log_fn = getattr(logger, "info_rank0", logger.info)
log_fn(f"Applying {pending_count} deferred HF weight conversions.")
for layer_name, converter in sorted(conversion_ctx["pending_converters"].items()):
realized_tensors = converter.convert(layer_name, model=model, config=model.config)
for target_name, tensor in realized_tensors.items():
if isinstance(tensor, list):
tensor = tensor[0]
if target_name in param_map:
self._copy_weights(param_map[target_name], tensor)
del realized_tensors
gc.collect()
def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
"""Resolve a HF model identifier or local path to a local directory containing checkpoint files.

View File

@@ -0,0 +1,137 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import torch
import torch.nn as nn
from safetensors.torch import save_file
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
NUM_EXPERTS = 11
HIDDEN_SIZE = 3
INTERMEDIATE_SIZE = 2
class Holder(nn.Module):
pass
class FakeFusedExpertsModel(nn.Module):
base_model_prefix = "model"
def __init__(self):
super().__init__()
self.config = types.SimpleNamespace(model_type="qwen3_moe")
self.model = Holder()
self.model.layers = nn.ModuleList([Holder()])
self.model.layers[0].mlp = Holder()
self.model.layers[0].mlp.experts = Holder()
self.model.layers[0].mlp.experts.gate_up_proj = nn.Parameter(
torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
)
self.model.layers[0].mlp.experts.down_proj = nn.Parameter(
torch.zeros(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
)
class FakeLegacyExpert(nn.Module):
def __init__(self):
super().__init__()
self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)
class FakeLegacyExpertsModel(nn.Module):
base_model_prefix = "model"
def __init__(self):
super().__init__()
self.config = types.SimpleNamespace(model_type="qwen3_moe")
self.model = Holder()
self.model.layers = nn.ModuleList([Holder()])
self.model.layers[0].mlp = Holder()
self.model.layers[0].mlp.experts = nn.ModuleList([FakeLegacyExpert() for _ in range(NUM_EXPERTS)])
def build_engine():
DistributedInterface()
return FSDP2Engine({"name": "fsdp2"})
def build_checkpoint():
ckpt = {}
gates, ups, downs = [], [], []
for i in range(NUM_EXPERTS):
# Use distinct values per expert so ordering bugs are easy to catch.
gate = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i))
up = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i) + 100.0)
down = torch.full((HIDDEN_SIZE, INTERMEDIATE_SIZE), float(i) + 200.0)
ckpt[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = gate
ckpt[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = up
ckpt[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = down
gates.append(gate)
ups.append(up)
downs.append(down)
return ckpt, gates, ups, downs
def test_fsdp2_gate_up_proj_loading(tmp_path):
engine = build_engine()
ckpt, gates, ups, downs = build_checkpoint()
save_file(ckpt, str(tmp_path / "model.safetensors"))
fused_model = FakeFusedExpertsModel()
conversion_ctx = engine._try_build_hf_weight_conversion_context(fused_model)
if conversion_ctx is not None:
# In transformers v5-style environments, legacy expert weights should be fused.
engine._load_weights_from_hf_checkpoint(fused_model, str(tmp_path))
expected_gate_up = torch.cat(
[torch.stack(gates, dim=0), torch.stack(ups, dim=0)],
dim=1,
)
expected_down = torch.stack(downs, dim=0)
experts = fused_model.model.layers[0].mlp.experts
assert torch.allclose(experts.gate_up_proj, expected_gate_up)
assert torch.allclose(experts.down_proj, expected_down)
# Check a double-digit expert index to ensure natural ordering is preserved.
assert torch.allclose(experts.gate_up_proj[2], expected_gate_up[2])
assert torch.allclose(experts.gate_up_proj[10], expected_gate_up[10])
assert torch.allclose(experts.down_proj[2], expected_down[2])
assert torch.allclose(experts.down_proj[10], expected_down[10])
else:
# In pre-v5 environments, the loader should fall back to direct copy.
legacy_model = FakeLegacyExpertsModel()
engine._load_weights_from_hf_checkpoint(legacy_model, str(tmp_path))
experts = legacy_model.model.layers[0].mlp.experts
for i in range(NUM_EXPERTS):
assert torch.allclose(experts[i].gate_proj.weight, gates[i])
assert torch.allclose(experts[i].up_proj.weight, ups[i])
assert torch.allclose(experts[i].down_proj.weight, downs[i])