[v1] fix init on meta in transformers v5 (#10414)

Author: jiaqiw09
Date:   2026-04-27 00:37:09 +08:00
Committed by: GitHub
Parent: c8890c32db
Commit: 9a0cfdccfa

2 changed files with 255 additions and 16 deletions

View File

@@ -17,7 +17,6 @@ import gc
 import os
 import torch
-import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 from peft.tuners.lora import LoraLayer
@@ -43,6 +42,37 @@ from ....utils.types import HFModel, Processor
 logger = get_logger(__name__)


+def _fallback_dot_natural_key(name: str):
+    parts = []
+    for part in name.split("."):
+        if part.isdigit():
+            parts.append((0, int(part)))
+        else:
+            parts.append((1, part))
+    return parts
+
+
+def _get_checkpoint_sort_key():
+    try:
+        from transformers.core_model_loading import dot_natural_key
+
+        return dot_natural_key
+    except ImportError:
+        return _fallback_dot_natural_key
+
+
+def _make_safetensor_loader(checkpoint_file: str, tensor_key: str):
+    # Delay tensor materialization until converter.convert() to reduce peak CPU memory.
+    # This works because HF WeightConverter accepts callables and materializes them later.
+    def _load_tensor():
+        from safetensors import safe_open
+
+        with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
+            return f.get_tensor(tensor_key)
+
+    return _load_tensor
+
+
 def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
     no_split_modules = getattr(model, "_no_split_modules", None)
     if no_split_modules:
@@ -129,7 +159,7 @@ class FSDP2Engine:
         self.offload_params = dist_config.get("offload_params", False)
         self.pin_memory = dist_config.get("pin_memory", True)
         self.dcp_path = dist_config.get("dcp_path", None)
-        self.device_mesh = self.dist_interface.data_device_mesh
+        self.device_mesh = self.dist_interface.model_device_mesh

         if self.device_mesh is None:
             logger.warning(
@@ -303,9 +333,6 @@ class FSDP2Engine:
         else:
             full_sd = {}

-        # Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
-        saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
-
         model = self.prepare_model(model)
         device = get_current_accelerator()
@@ -316,11 +343,6 @@ class FSDP2Engine:
         options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
         set_model_state_dict(model, full_sd, options=options)

-        # Broadcast and restore non-persistent buffers
-        buffers_to_sync = [saved_buffers]
-        dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
-        self._restore_non_persistent_buffers(model, buffers_to_sync[0])
-
         if self.rank == 0:
             logger.info("init_on_rank0 sync complete.")
@@ -387,11 +409,37 @@ class FSDP2Engine:
             logger.error(f"Failed to load from DCP: {e}")
             raise e

+    def _try_build_hf_weight_conversion_context(self, model: HFModel) -> dict | None:
+        try:
+            from transformers.conversion_mapping import get_model_conversion_mapping
+            from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key
+        except ImportError:
+            return None
+
+        weight_mapping = get_model_conversion_mapping(model)
+        if not weight_mapping:
+            return None
+
+        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
+        converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
+        return {
+            "prefix": getattr(model, "base_model_prefix", ""),
+            "meta_state_dict": model.state_dict(),
+            "rename_source_key": rename_source_key,
+            "renamings": renamings,
+            "converters": converters,
+            "converter_templates": {
+                pattern: converter for converter in converters for pattern in converter.source_patterns
+            },
+            "pending_converters": {},
+        }
+
     def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
         import glob
         import json

         hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
+        sort_key = _get_checkpoint_sort_key()

         if self.rank == 0:
             logger.info(f"Loading weights from {hf_model_path} ...")
@@ -428,6 +476,7 @@ class FSDP2Engine:
             raise ValueError(f"No checkpoint files found in {hf_model_path}")

         param_map = dict(model.named_parameters())
+        conversion_ctx = self._try_build_hf_weight_conversion_context(model)
         total_files = len(checkpoint_files)

         for i, ckpt_file in enumerate(checkpoint_files):
@@ -438,18 +487,71 @@ class FSDP2Engine:
                 from safetensors import safe_open

                 with safe_open(ckpt_file, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        if key in param_map:
+                    for key in sorted(f.keys(), key=sort_key):
+                        renamed_key = key
+                        source_pattern = None
+                        if conversion_ctx is not None:
+                            renamed_key, source_pattern = conversion_ctx["rename_source_key"](
+                                key,
+                                conversion_ctx["renamings"],
+                                conversion_ctx["converters"],
+                                prefix=conversion_ctx["prefix"],
+                                meta_state_dict=conversion_ctx["meta_state_dict"],
+                            )
+
+                        if source_pattern is not None:
+                            template = conversion_ctx["converter_templates"][source_pattern]
+                            converter = conversion_ctx["pending_converters"].setdefault(
+                                renamed_key, copy.deepcopy(template)
+                            )
+                            converter.add_tensor(
+                                renamed_key,
+                                key,
+                                source_pattern,
+                                _make_safetensor_loader(ckpt_file, key),
+                            )
+                        elif renamed_key in param_map:
                             tensor = f.get_tensor(key)
-                            self._copy_weights(param_map[key], tensor)
+                            self._copy_weights(param_map[renamed_key], tensor)
             else:
                 state_dict = torch.load(ckpt_file, map_location="cpu")
-                for key, tensor in state_dict.items():
-                    if key in param_map:
-                        self._copy_weights(param_map[key], tensor)
+                for key, tensor in sorted(state_dict.items(), key=lambda item: sort_key(item[0])):
+                    renamed_key = key
+                    source_pattern = None
+                    if conversion_ctx is not None:
+                        renamed_key, source_pattern = conversion_ctx["rename_source_key"](
+                            key,
+                            conversion_ctx["renamings"],
+                            conversion_ctx["converters"],
+                            prefix=conversion_ctx["prefix"],
+                            meta_state_dict=conversion_ctx["meta_state_dict"],
+                        )
+
+                    if source_pattern is not None:
+                        template = conversion_ctx["converter_templates"][source_pattern]
+                        converter = conversion_ctx["pending_converters"].setdefault(
+                            renamed_key, copy.deepcopy(template)
+                        )
+                        converter.add_tensor(renamed_key, key, source_pattern, tensor)
+                    elif renamed_key in param_map:
+                        self._copy_weights(param_map[renamed_key], tensor)
                 del state_dict
                 gc.collect()

+        if conversion_ctx is not None:
+            pending_count = len(conversion_ctx["pending_converters"])
+            log_fn = getattr(logger, "info_rank0", logger.info)
+            log_fn(f"Applying {pending_count} deferred HF weight conversions.")
+            for layer_name, converter in sorted(conversion_ctx["pending_converters"].items()):
+                realized_tensors = converter.convert(layer_name, model=model, config=model.config)
+                for target_name, tensor in realized_tensors.items():
+                    if isinstance(tensor, list):
+                        tensor = tensor[0]
+                    if target_name in param_map:
+                        self._copy_weights(param_map[target_name], tensor)
+                del realized_tensors
+                gc.collect()
+
     def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
         """Resolve a HF model identifier or local path to a local directory containing checkpoint files.

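A side note on the new sort helper above: checkpoint keys are natural-sorted so that numeric path segments such as expert indices compare as integers rather than as strings, which matters once a model has ten or more experts. A small, self-contained illustration (natural_key here is a hypothetical stand-in that simply mirrors _fallback_dot_natural_key):

def natural_key(name: str):
    # Digit segments compare numerically, all other segments as strings.
    return [(0, int(part)) if part.isdigit() else (1, part) for part in name.split(".")]

keys = [
    "model.layers.0.mlp.experts.10.gate_proj.weight",
    "model.layers.0.mlp.experts.2.gate_proj.weight",
]
assert sorted(keys)[0].endswith(".10.gate_proj.weight")  # lexicographic: "10" sorts before "2"
assert sorted(keys, key=natural_key)[0].endswith(".2.gate_proj.weight")  # natural: 2 before 10
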
View File

@@ -0,0 +1,137 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import torch
import torch.nn as nn
from safetensors.torch import save_file
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
NUM_EXPERTS = 11
HIDDEN_SIZE = 3
INTERMEDIATE_SIZE = 2


class Holder(nn.Module):
    pass


class FakeFusedExpertsModel(nn.Module):
    base_model_prefix = "model"

    def __init__(self):
        super().__init__()
        self.config = types.SimpleNamespace(model_type="qwen3_moe")
        self.model = Holder()
        self.model.layers = nn.ModuleList([Holder()])
        self.model.layers[0].mlp = Holder()
        self.model.layers[0].mlp.experts = Holder()
        self.model.layers[0].mlp.experts.gate_up_proj = nn.Parameter(
            torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
        )
        self.model.layers[0].mlp.experts.down_proj = nn.Parameter(
            torch.zeros(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
        )


class FakeLegacyExpert(nn.Module):
    def __init__(self):
        super().__init__()
        self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)


class FakeLegacyExpertsModel(nn.Module):
    base_model_prefix = "model"

    def __init__(self):
        super().__init__()
        self.config = types.SimpleNamespace(model_type="qwen3_moe")
        self.model = Holder()
        self.model.layers = nn.ModuleList([Holder()])
        self.model.layers[0].mlp = Holder()
        self.model.layers[0].mlp.experts = nn.ModuleList([FakeLegacyExpert() for _ in range(NUM_EXPERTS)])


def build_engine():
    DistributedInterface()
    return FSDP2Engine({"name": "fsdp2"})


def build_checkpoint():
    ckpt = {}
    gates, ups, downs = [], [], []
    for i in range(NUM_EXPERTS):
        # Use distinct values per expert so ordering bugs are easy to catch.
        gate = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i))
        up = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i) + 100.0)
        down = torch.full((HIDDEN_SIZE, INTERMEDIATE_SIZE), float(i) + 200.0)
        ckpt[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = gate
        ckpt[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = up
        ckpt[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = down
        gates.append(gate)
        ups.append(up)
        downs.append(down)
    return ckpt, gates, ups, downs


def test_fsdp2_gate_up_proj_loading(tmp_path):
    engine = build_engine()
    ckpt, gates, ups, downs = build_checkpoint()
    save_file(ckpt, str(tmp_path / "model.safetensors"))

    fused_model = FakeFusedExpertsModel()
    conversion_ctx = engine._try_build_hf_weight_conversion_context(fused_model)
    if conversion_ctx is not None:
        # In transformers v5-style environments, legacy expert weights should be fused.
        engine._load_weights_from_hf_checkpoint(fused_model, str(tmp_path))
        expected_gate_up = torch.cat(
            [torch.stack(gates, dim=0), torch.stack(ups, dim=0)],
            dim=1,
        )
        expected_down = torch.stack(downs, dim=0)
        experts = fused_model.model.layers[0].mlp.experts
        assert torch.allclose(experts.gate_up_proj, expected_gate_up)
        assert torch.allclose(experts.down_proj, expected_down)
        # Check a double-digit expert index to ensure natural ordering is preserved.
        assert torch.allclose(experts.gate_up_proj[2], expected_gate_up[2])
        assert torch.allclose(experts.gate_up_proj[10], expected_gate_up[10])
        assert torch.allclose(experts.down_proj[2], expected_down[2])
        assert torch.allclose(experts.down_proj[10], expected_down[10])
    else:
        # In pre-v5 environments, the loader should fall back to direct copy.
        legacy_model = FakeLegacyExpertsModel()
        engine._load_weights_from_hf_checkpoint(legacy_model, str(tmp_path))
        experts = legacy_model.model.layers[0].mlp.experts
        for i in range(NUM_EXPERTS):
            assert torch.allclose(experts[i].gate_proj.weight, gates[i])
            assert torch.allclose(experts[i].up_proj.weight, ups[i])
            assert torch.allclose(experts[i].down_proj.weight, downs[i])
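
For orientation, the fused-experts branch above exercises roughly the following flow, shown here as an illustrative sketch rather than the commit's actual implementation. It assumes the transformers v5 entry points that fsdp2.py imports (get_model_conversion_mapping, rename_source_key, WeightConverter, WeightRenaming); the function name convert_checkpoint_state_dict is made up for this example.

import copy

from transformers.conversion_mapping import get_model_conversion_mapping
from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key


def convert_checkpoint_state_dict(model, checkpoint_sd: dict) -> dict:
    """Rename v4-style keys and realize v5 fused tensors (e.g. per-expert gate/up -> gate_up_proj)."""
    mapping = get_model_conversion_mapping(model)
    renamings = [entry for entry in mapping if isinstance(entry, WeightRenaming)]
    converters = [entry for entry in mapping if isinstance(entry, WeightConverter)]
    templates = {pattern: conv for conv in converters for pattern in conv.source_patterns}
    meta_sd = model.state_dict()
    pending, out = {}, {}
    for key, tensor in checkpoint_sd.items():
        new_key, pattern = rename_source_key(
            key, renamings, converters, prefix=model.base_model_prefix, meta_state_dict=meta_sd
        )
        if pattern is not None:
            # Group every source tensor that feeds the same fused target; fusion happens in convert().
            converter = pending.setdefault(new_key, copy.deepcopy(templates[pattern]))
            converter.add_tensor(new_key, key, pattern, tensor)
        else:
            out[new_key] = tensor
    for layer_name, converter in pending.items():
        for target_name, realized in converter.convert(layer_name, model=model, config=model.config).items():
            out[target_name] = realized[0] if isinstance(realized, list) else realized
    return out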