mirror of https://github.com/hiyouga/LLaMA-Factory.git
[v1] fix init on meta in transformers v5 (#10414)
@@ -17,7 +17,6 @@ import gc
 import os
 
 import torch
-import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 from peft.tuners.lora import LoraLayer
@@ -43,6 +42,37 @@ from ....utils.types import HFModel, Processor
 logger = get_logger(__name__)
 
 
+def _fallback_dot_natural_key(name: str):
+    parts = []
+    for part in name.split("."):
+        if part.isdigit():
+            parts.append((0, int(part)))
+        else:
+            parts.append((1, part))
+    return parts
+
+
+def _get_checkpoint_sort_key():
+    try:
+        from transformers.core_model_loading import dot_natural_key
+
+        return dot_natural_key
+    except ImportError:
+        return _fallback_dot_natural_key
+
+
+def _make_safetensor_loader(checkpoint_file: str, tensor_key: str):
+    # Delay tensor materialization until converter.convert() to reduce peak CPU memory.
+    # This works because HF WeightConverter accepts callables and materializes them later.
+    def _load_tensor():
+        from safetensors import safe_open
+
+        with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
+            return f.get_tensor(tensor_key)
+
+    return _load_tensor
+
+
 def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
     no_split_modules = getattr(model, "_no_split_modules", None)
     if no_split_modules:
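As a side note (not part of the commit), here is a quick illustration of why the natural-key fallback added above matters when checkpoint keys contain expert indices: plain lexicographic sorting puts expert 10 before expert 2, while _fallback_dot_natural_key keeps numeric order.

# Illustration only: compares plain sorting with the natural key defined above.
keys = [
    "model.layers.0.mlp.experts.10.up_proj.weight",
    "model.layers.0.mlp.experts.2.up_proj.weight",
]
assert sorted(keys)[0].endswith("experts.10.up_proj.weight")  # lexicographic: 10 sorts before 2
assert sorted(keys, key=_fallback_dot_natural_key)[0].endswith("experts.2.up_proj.weight")  # natural: 2 before 10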
@@ -129,7 +159,7 @@ class FSDP2Engine:
         self.offload_params = dist_config.get("offload_params", False)
         self.pin_memory = dist_config.get("pin_memory", True)
         self.dcp_path = dist_config.get("dcp_path", None)
-        self.device_mesh = self.dist_interface.data_device_mesh
+        self.device_mesh = self.dist_interface.model_device_mesh
 
         if self.device_mesh is None:
             logger.warning(
@@ -303,9 +333,6 @@ class FSDP2Engine:
         else:
             full_sd = {}
 
-        # Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
-        saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
-
         model = self.prepare_model(model)
 
         device = get_current_accelerator()
@@ -316,11 +343,6 @@ class FSDP2Engine:
         options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
         set_model_state_dict(model, full_sd, options=options)
 
-        # Broadcast and restore non-persistent buffers
-        buffers_to_sync = [saved_buffers]
-        dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
-        self._restore_non_persistent_buffers(model, buffers_to_sync[0])
-
         if self.rank == 0:
             logger.info("init_on_rank0 sync complete.")
 
@@ -387,11 +409,37 @@ class FSDP2Engine:
             logger.error(f"Failed to load from DCP: {e}")
             raise e
 
+    def _try_build_hf_weight_conversion_context(self, model: HFModel) -> dict | None:
+        try:
+            from transformers.conversion_mapping import get_model_conversion_mapping
+            from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key
+        except ImportError:
+            return None
+
+        weight_mapping = get_model_conversion_mapping(model)
+        if not weight_mapping:
+            return None
+
+        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
+        converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
+        return {
+            "prefix": getattr(model, "base_model_prefix", ""),
+            "meta_state_dict": model.state_dict(),
+            "rename_source_key": rename_source_key,
+            "renamings": renamings,
+            "converters": converters,
+            "converter_templates": {
+                pattern: converter for converter in converters for pattern in converter.source_patterns
+            },
+            "pending_converters": {},
+        }
+
     def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
         import glob
         import json
 
         hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
+        sort_key = _get_checkpoint_sort_key()
 
         if self.rank == 0:
             logger.info(f"Loading weights from {hf_model_path} ...")
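An aside (not part of the commit): the pending_converters bookkeeping above boils down to a setdefault-plus-deepcopy pattern, so each fused target gets exactly one private clone of its converter template and every matching source key feeds that same clone. A minimal stand-alone sketch with hypothetical names:

import copy

# Sketch only: plain dicts stand in for WeightConverter templates; names are hypothetical.
templates = {"experts.*.gate_proj.weight": {"tensors": []}}
pending = {}

def queue_tensor(target_key, source_pattern, tensor):
    converter = pending.setdefault(target_key, copy.deepcopy(templates[source_pattern]))
    converter["tensors"].append(tensor)  # the real code calls converter.add_tensor(...)

queue_tensor("experts.gate_up_proj", "experts.*.gate_proj.weight", "w0")
queue_tensor("experts.gate_up_proj", "experts.*.gate_proj.weight", "w1")
assert len(pending) == 1 and pending["experts.gate_up_proj"]["tensors"] == ["w0", "w1"]
assert templates["experts.*.gate_proj.weight"]["tensors"] == []  # the template itself is never mutated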
@@ -428,6 +476,7 @@ class FSDP2Engine:
             raise ValueError(f"No checkpoint files found in {hf_model_path}")
 
         param_map = dict(model.named_parameters())
+        conversion_ctx = self._try_build_hf_weight_conversion_context(model)
         total_files = len(checkpoint_files)
 
         for i, ckpt_file in enumerate(checkpoint_files):
@@ -438,18 +487,71 @@ class FSDP2Engine:
                 from safetensors import safe_open
 
                 with safe_open(ckpt_file, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        if key in param_map:
+                    for key in sorted(f.keys(), key=sort_key):
+                        renamed_key = key
+                        source_pattern = None
+                        if conversion_ctx is not None:
+                            renamed_key, source_pattern = conversion_ctx["rename_source_key"](
+                                key,
+                                conversion_ctx["renamings"],
+                                conversion_ctx["converters"],
+                                prefix=conversion_ctx["prefix"],
+                                meta_state_dict=conversion_ctx["meta_state_dict"],
+                            )
+
+                        if source_pattern is not None:
+                            template = conversion_ctx["converter_templates"][source_pattern]
+                            converter = conversion_ctx["pending_converters"].setdefault(
+                                renamed_key, copy.deepcopy(template)
+                            )
+                            converter.add_tensor(
+                                renamed_key,
+                                key,
+                                source_pattern,
+                                _make_safetensor_loader(ckpt_file, key),
+                            )
+                        elif renamed_key in param_map:
                             tensor = f.get_tensor(key)
-                            self._copy_weights(param_map[key], tensor)
+                            self._copy_weights(param_map[renamed_key], tensor)
             else:
                 state_dict = torch.load(ckpt_file, map_location="cpu")
-                for key, tensor in state_dict.items():
-                    if key in param_map:
-                        self._copy_weights(param_map[key], tensor)
+                for key, tensor in sorted(state_dict.items(), key=lambda item: sort_key(item[0])):
+                    renamed_key = key
+                    source_pattern = None
+                    if conversion_ctx is not None:
+                        renamed_key, source_pattern = conversion_ctx["rename_source_key"](
+                            key,
+                            conversion_ctx["renamings"],
+                            conversion_ctx["converters"],
+                            prefix=conversion_ctx["prefix"],
+                            meta_state_dict=conversion_ctx["meta_state_dict"],
+                        )
+
+                    if source_pattern is not None:
+                        template = conversion_ctx["converter_templates"][source_pattern]
+                        converter = conversion_ctx["pending_converters"].setdefault(
+                            renamed_key, copy.deepcopy(template)
+                        )
+                        converter.add_tensor(renamed_key, key, source_pattern, tensor)
+                    elif renamed_key in param_map:
+                        self._copy_weights(param_map[renamed_key], tensor)
                 del state_dict
             gc.collect()
+
+        if conversion_ctx is not None:
+            pending_count = len(conversion_ctx["pending_converters"])
+            log_fn = getattr(logger, "info_rank0", logger.info)
+            log_fn(f"Applying {pending_count} deferred HF weight conversions.")
+            for layer_name, converter in sorted(conversion_ctx["pending_converters"].items()):
+                realized_tensors = converter.convert(layer_name, model=model, config=model.config)
+                for target_name, tensor in realized_tensors.items():
+                    if isinstance(tensor, list):
+                        tensor = tensor[0]
+                    if target_name in param_map:
+                        self._copy_weights(param_map[target_name], tensor)
+                del realized_tensors
+                gc.collect()
 
     def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
         """Resolve a HF model identifier or local path to a local directory containing checkpoint files.
 
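Also outside the diff: a small sketch of how the deferred loader returned by _make_safetensor_loader behaves. Nothing is read from disk until the callable is invoked (typically inside converter.convert()), which is what keeps peak CPU memory low during the per-file scan. The file path and tensor key below are made up for the demo.

import torch
from safetensors.torch import save_file

# Demo only: write a tiny file, build the lazy loader, then materialize on demand.
save_file({"w": torch.ones(2, 2)}, "/tmp/demo.safetensors")
loader = _make_safetensor_loader("/tmp/demo.safetensors", "w")  # no I/O happens here
tensor = loader()  # the file is opened and "w" is read only at this point
assert tensor.shape == (2, 2) and bool(tensor.eq(1).all())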
@@ -0,0 +1,137 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import types
+
+import torch
+import torch.nn as nn
+from safetensors.torch import save_file
+
+from llamafactory.v1.accelerator.interface import DistributedInterface
+from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
+
+
+NUM_EXPERTS = 11
+HIDDEN_SIZE = 3
+INTERMEDIATE_SIZE = 2
+
+
+class Holder(nn.Module):
+    pass
+
+
+class FakeFusedExpertsModel(nn.Module):
+    base_model_prefix = "model"
+
+    def __init__(self):
+        super().__init__()
+        self.config = types.SimpleNamespace(model_type="qwen3_moe")
+
+        self.model = Holder()
+        self.model.layers = nn.ModuleList([Holder()])
+        self.model.layers[0].mlp = Holder()
+        self.model.layers[0].mlp.experts = Holder()
+        self.model.layers[0].mlp.experts.gate_up_proj = nn.Parameter(
+            torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
+        )
+        self.model.layers[0].mlp.experts.down_proj = nn.Parameter(
+            torch.zeros(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
+        )
+
+
+class FakeLegacyExpert(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
+        self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
+        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)
+
+
+class FakeLegacyExpertsModel(nn.Module):
+    base_model_prefix = "model"
+
+    def __init__(self):
+        super().__init__()
+        self.config = types.SimpleNamespace(model_type="qwen3_moe")
+
+        self.model = Holder()
+        self.model.layers = nn.ModuleList([Holder()])
+        self.model.layers[0].mlp = Holder()
+        self.model.layers[0].mlp.experts = nn.ModuleList([FakeLegacyExpert() for _ in range(NUM_EXPERTS)])
+
+
+def build_engine():
+    DistributedInterface()
+    return FSDP2Engine({"name": "fsdp2"})
+
+
+def build_checkpoint():
+    ckpt = {}
+    gates, ups, downs = [], [], []
+
+    for i in range(NUM_EXPERTS):
+        # Use distinct values per expert so ordering bugs are easy to catch.
+        gate = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i))
+        up = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i) + 100.0)
+        down = torch.full((HIDDEN_SIZE, INTERMEDIATE_SIZE), float(i) + 200.0)
+
+        ckpt[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = gate
+        ckpt[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = up
+        ckpt[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = down
+
+        gates.append(gate)
+        ups.append(up)
+        downs.append(down)
+
+    return ckpt, gates, ups, downs
+
+
+def test_fsdp2_gate_up_proj_loading(tmp_path):
+    engine = build_engine()
+    ckpt, gates, ups, downs = build_checkpoint()
+    save_file(ckpt, str(tmp_path / "model.safetensors"))
+
+    fused_model = FakeFusedExpertsModel()
+    conversion_ctx = engine._try_build_hf_weight_conversion_context(fused_model)
+
+    if conversion_ctx is not None:
+        # In transformers v5-style environments, legacy expert weights should be fused.
+        engine._load_weights_from_hf_checkpoint(fused_model, str(tmp_path))
+
+        expected_gate_up = torch.cat(
+            [torch.stack(gates, dim=0), torch.stack(ups, dim=0)],
+            dim=1,
+        )
+        expected_down = torch.stack(downs, dim=0)
+
+        experts = fused_model.model.layers[0].mlp.experts
+        assert torch.allclose(experts.gate_up_proj, expected_gate_up)
+        assert torch.allclose(experts.down_proj, expected_down)
+
+        # Check a double-digit expert index to ensure natural ordering is preserved.
+        assert torch.allclose(experts.gate_up_proj[2], expected_gate_up[2])
+        assert torch.allclose(experts.gate_up_proj[10], expected_gate_up[10])
+        assert torch.allclose(experts.down_proj[2], expected_down[2])
+        assert torch.allclose(experts.down_proj[10], expected_down[10])
+
+    else:
+        # In pre-v5 environments, the loader should fall back to direct copy.
+        legacy_model = FakeLegacyExpertsModel()
+        engine._load_weights_from_hf_checkpoint(legacy_model, str(tmp_path))
+
+        experts = legacy_model.model.layers[0].mlp.experts
+        for i in range(NUM_EXPERTS):
+            assert torch.allclose(experts[i].gate_proj.weight, gates[i])
+            assert torch.allclose(experts[i].up_proj.weight, ups[i])
+            assert torch.allclose(experts[i].down_proj.weight, downs[i])
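For orientation (not part of the commit), the shape arithmetic behind expected_gate_up in the test above: stacking the 11 per-expert gate and up matrices of shape (INTERMEDIATE_SIZE, HIDDEN_SIZE) and concatenating along dim=1 reproduces the fused gate_up_proj layout (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE).

import torch

# Shape check only, using the same sizes as the test: 11 experts, intermediate 2, hidden 3.
gates = [torch.zeros(2, 3) for _ in range(11)]  # per-expert gate_proj.weight
ups = [torch.zeros(2, 3) for _ in range(11)]    # per-expert up_proj.weight
fused = torch.cat([torch.stack(gates, dim=0), torch.stack(ups, dim=0)], dim=1)
assert fused.shape == (11, 4, 3)  # (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)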