diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
index 32b424443..88bdb4f4e 100644
--- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
+++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py
@@ -17,7 +17,7 @@
+import copy
 import gc
 import os
 import torch
-import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
 from peft.tuners.lora import LoraLayer
@@ -43,6 +42,37 @@ from ....utils.types import HFModel, Processor
 logger = get_logger(__name__)


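+# Sort checkpoint keys naturally so that, e.g., "experts.2" precedes "experts.10";
+# a plain lexicographic sort would put "experts.10" first and scramble stacked
+# expert order. The fallback below is assumed to mirror transformers' dot_natural_key.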
+def _fallback_dot_natural_key(name: str):
+    parts = []
+    for part in name.split("."):
+        if part.isdigit():
+            parts.append((0, int(part)))
+        else:
+            parts.append((1, part))
+    return parts
+
+
+def _get_checkpoint_sort_key():
+    try:
+        from transformers.core_model_loading import dot_natural_key
+
+        return dot_natural_key
+    except ImportError:
+        return _fallback_dot_natural_key
+
+
+def _make_safetensor_loader(checkpoint_file: str, tensor_key: str):
+    # Delay tensor materialization until converter.convert() to reduce peak CPU memory.
+    # This works because HF WeightConverter accepts callables and materializes them later.
+    def _load_tensor():
+        from safetensors import safe_open
+
+        with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
+            return f.get_tensor(tensor_key)
+
+    return _load_tensor
+
+
 def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
     no_split_modules = getattr(model, "_no_split_modules", None)
     if no_split_modules:
@@ -129,7 +159,7 @@ class FSDP2Engine:
         self.offload_params = dist_config.get("offload_params", False)
         self.pin_memory = dist_config.get("pin_memory", True)
         self.dcp_path = dist_config.get("dcp_path", None)
-        self.device_mesh = self.dist_interface.data_device_mesh
+        self.device_mesh = self.dist_interface.model_device_mesh

         if self.device_mesh is None:
             logger.warning(
@@ -303,9 +333,6 @@ class FSDP2Engine:
         else:
             full_sd = {}

-        # Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
-        saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
-
         model = self.prepare_model(model)

         device = get_current_accelerator()
@@ -316,11 +343,6 @@ class FSDP2Engine:
         options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
         set_model_state_dict(model, full_sd, options=options)

-        # Broadcast and restore non-persistent buffers
-        buffers_to_sync = [saved_buffers]
-        dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
-        self._restore_non_persistent_buffers(model, buffers_to_sync[0])
-
         if self.rank == 0:
             logger.info("init_on_rank0 sync complete.")

@@ -387,11 +409,37 @@ class FSDP2Engine:
             logger.error(f"Failed to load from DCP: {e}")
             raise e

+    def _try_build_hf_weight_conversion_context(self, model: HFModel) -> dict | None:
+        # Build rename/convert state from transformers' per-model conversion mapping
+        # (v5-style API); return None on older releases so callers match keys directly.
+        try:
+            from transformers.conversion_mapping import get_model_conversion_mapping
+            from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key
+        except ImportError:
+            return None
+
+        weight_mapping = get_model_conversion_mapping(model)
+        if not weight_mapping:
+            return None
+
+        renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
+        converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
+        return {
+            "prefix": getattr(model, "base_model_prefix", ""),
+            "meta_state_dict": model.state_dict(),
+            "rename_source_key": rename_source_key,
+            "renamings": renamings,
+            "converters": converters,
+            "converter_templates": {
+                pattern: converter for converter in converters for pattern in converter.source_patterns
+            },
+            "pending_converters": {},
+        }
+
     def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
         import glob
         import json

         hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
+        sort_key = _get_checkpoint_sort_key()

         if self.rank == 0:
             logger.info(f"Loading weights from {hf_model_path} ...")
@@ -428,6 +476,7 @@
             raise ValueError(f"No checkpoint files found in {hf_model_path}")

         param_map = dict(model.named_parameters())
+        conversion_ctx = self._try_build_hf_weight_conversion_context(model)
         total_files = len(checkpoint_files)

         for i, ckpt_file in enumerate(checkpoint_files):
@@ -438,18 +487,71 @@
                 from safetensors import safe_open

                 with safe_open(ckpt_file, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        if key in param_map:
+                    for key in sorted(f.keys(), key=sort_key):
+                        renamed_key = key
+                        source_pattern = None
+                        if conversion_ctx is not None:
+                            renamed_key, source_pattern = conversion_ctx["rename_source_key"](
+                                key,
+                                conversion_ctx["renamings"],
+                                conversion_ctx["converters"],
+                                prefix=conversion_ctx["prefix"],
+                                meta_state_dict=conversion_ctx["meta_state_dict"],
+                            )
+
+                        if source_pattern is not None:
+                            template = conversion_ctx["converter_templates"][source_pattern]
+                            converter = conversion_ctx["pending_converters"].setdefault(
+                                renamed_key, copy.deepcopy(template)
+                            )
+                            converter.add_tensor(
+                                renamed_key,
+                                key,
+                                source_pattern,
+                                _make_safetensor_loader(ckpt_file, key),
+                            )
+                        elif renamed_key in param_map:
                             tensor = f.get_tensor(key)
-                            self._copy_weights(param_map[key], tensor)
+                            self._copy_weights(param_map[renamed_key], tensor)
             else:
                 state_dict = torch.load(ckpt_file, map_location="cpu")
-                for key, tensor in state_dict.items():
-                    if key in param_map:
-                        self._copy_weights(param_map[key], tensor)
+                for key, tensor in sorted(state_dict.items(), key=lambda item: sort_key(item[0])):
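+                    # Same routing as the safetensors branch above: rename the key,
+                    # queue it on a deferred converter, or fall back to a direct copy.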
+                    renamed_key = key
+                    source_pattern = None
+                    if conversion_ctx is not None:
+                        renamed_key, source_pattern = conversion_ctx["rename_source_key"](
+                            key,
+                            conversion_ctx["renamings"],
+                            conversion_ctx["converters"],
+                            prefix=conversion_ctx["prefix"],
+                            meta_state_dict=conversion_ctx["meta_state_dict"],
+                        )
+
+                    if source_pattern is not None:
+                        template = conversion_ctx["converter_templates"][source_pattern]
+                        converter = conversion_ctx["pending_converters"].setdefault(
+                            renamed_key, copy.deepcopy(template)
+                        )
+                        converter.add_tensor(renamed_key, key, source_pattern, tensor)
+                    elif renamed_key in param_map:
+                        self._copy_weights(param_map[renamed_key], tensor)
                 del state_dict

             gc.collect()

+        if conversion_ctx is not None:
+            pending_count = len(conversion_ctx["pending_converters"])
+            log_fn = getattr(logger, "info_rank0", logger.info)
+            log_fn(f"Applying {pending_count} deferred HF weight conversions.")
+            for layer_name, converter in sorted(conversion_ctx["pending_converters"].items()):
+                realized_tensors = converter.convert(layer_name, model=model, config=model.config)
+                for target_name, tensor in realized_tensors.items():
+                    if isinstance(tensor, list):
+                        tensor = tensor[0]
+                    if target_name in param_map:
+                        self._copy_weights(param_map[target_name], tensor)
+                del realized_tensors
+                gc.collect()
+
     def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
         """Resolve a HF model identifier or local path to a local directory containing checkpoint files.
diff --git a/tests_v1/plugins/trainer_plugins/distributed/test_fsdp2_weight_convert.py b/tests_v1/plugins/trainer_plugins/distributed/test_fsdp2_weight_convert.py
new file mode 100644
index 000000000..c1bb94231
--- /dev/null
+++ b/tests_v1/plugins/trainer_plugins/distributed/test_fsdp2_weight_convert.py
@@ -0,0 +1,137 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
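+
+# Exercises FSDP2Engine._load_weights_from_hf_checkpoint on a checkpoint saved in the
+# legacy per-expert layout: with a conversion-mapping-aware transformers install, the
+# expert weights should be fused into stacked gate_up_proj / down_proj parameters;
+# without one, keys should be copied one-to-one into per-expert modules.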
+
+import types
+
+import torch
+import torch.nn as nn
+from safetensors.torch import save_file
+
+from llamafactory.v1.accelerator.interface import DistributedInterface
+from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
+
+
+NUM_EXPERTS = 11
+HIDDEN_SIZE = 3
+INTERMEDIATE_SIZE = 2
+
+
+class Holder(nn.Module):
+    pass
+
+
+class FakeFusedExpertsModel(nn.Module):
+    base_model_prefix = "model"
+
+    def __init__(self):
+        super().__init__()
+        self.config = types.SimpleNamespace(model_type="qwen3_moe")
+
+        self.model = Holder()
+        self.model.layers = nn.ModuleList([Holder()])
+        self.model.layers[0].mlp = Holder()
+        self.model.layers[0].mlp.experts = Holder()
+        self.model.layers[0].mlp.experts.gate_up_proj = nn.Parameter(
+            torch.zeros(NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE)
+        )
+        self.model.layers[0].mlp.experts.down_proj = nn.Parameter(
+            torch.zeros(NUM_EXPERTS, HIDDEN_SIZE, INTERMEDIATE_SIZE)
+        )
+
+
+class FakeLegacyExpert(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
+        self.up_proj = nn.Linear(HIDDEN_SIZE, INTERMEDIATE_SIZE, bias=False)
+        self.down_proj = nn.Linear(INTERMEDIATE_SIZE, HIDDEN_SIZE, bias=False)
+
+
+class FakeLegacyExpertsModel(nn.Module):
+    base_model_prefix = "model"
+
+    def __init__(self):
+        super().__init__()
+        self.config = types.SimpleNamespace(model_type="qwen3_moe")
+
+        self.model = Holder()
+        self.model.layers = nn.ModuleList([Holder()])
+        self.model.layers[0].mlp = Holder()
+        self.model.layers[0].mlp.experts = nn.ModuleList([FakeLegacyExpert() for _ in range(NUM_EXPERTS)])
+
+
+def build_engine():
+    DistributedInterface()
+    return FSDP2Engine({"name": "fsdp2"})
+
+
+def build_checkpoint():
+    ckpt = {}
+    gates, ups, downs = [], [], []
+
+    for i in range(NUM_EXPERTS):
+        # Use distinct values per expert so ordering bugs are easy to catch.
+        gate = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i))
+        up = torch.full((INTERMEDIATE_SIZE, HIDDEN_SIZE), float(i) + 100.0)
+        down = torch.full((HIDDEN_SIZE, INTERMEDIATE_SIZE), float(i) + 200.0)
+
+        ckpt[f"model.layers.0.mlp.experts.{i}.gate_proj.weight"] = gate
+        ckpt[f"model.layers.0.mlp.experts.{i}.up_proj.weight"] = up
+        ckpt[f"model.layers.0.mlp.experts.{i}.down_proj.weight"] = down
+
+        gates.append(gate)
+        ups.append(up)
+        downs.append(down)
+
+    return ckpt, gates, ups, downs
+
+
+def test_fsdp2_gate_up_proj_loading(tmp_path):
+    engine = build_engine()
+    ckpt, gates, ups, downs = build_checkpoint()
+    save_file(ckpt, str(tmp_path / "model.safetensors"))
+
+    fused_model = FakeFusedExpertsModel()
+    conversion_ctx = engine._try_build_hf_weight_conversion_context(fused_model)
+
+    if conversion_ctx is not None:
+        # In transformers v5-style environments, legacy expert weights should be fused.
+        engine._load_weights_from_hf_checkpoint(fused_model, str(tmp_path))
+
+        expected_gate_up = torch.cat(
+            [torch.stack(gates, dim=0), torch.stack(ups, dim=0)],
+            dim=1,
+        )
+        expected_down = torch.stack(downs, dim=0)
+
+        experts = fused_model.model.layers[0].mlp.experts
+        assert torch.allclose(experts.gate_up_proj, expected_gate_up)
+        assert torch.allclose(experts.down_proj, expected_down)
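+
+        # Shape note: gate_up_proj is (NUM_EXPERTS, 2 * INTERMEDIATE_SIZE, HIDDEN_SIZE),
+        # i.e. experts stacked on dim 0 with [gate; up] concatenated on dim 1.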
+
+        # Check a double-digit expert index to ensure natural ordering is preserved.
+        assert torch.allclose(experts.gate_up_proj[2], expected_gate_up[2])
+        assert torch.allclose(experts.gate_up_proj[10], expected_gate_up[10])
+        assert torch.allclose(experts.down_proj[2], expected_down[2])
+        assert torch.allclose(experts.down_proj[10], expected_down[10])
+
+    else:
+        # In pre-v5 environments, the loader should fall back to direct copy.
+        legacy_model = FakeLegacyExpertsModel()
+        engine._load_weights_from_hf_checkpoint(legacy_model, str(tmp_path))
+
+        experts = legacy_model.model.layers[0].mlp.experts
+        for i in range(NUM_EXPERTS):
+            assert torch.allclose(experts[i].gate_proj.weight, gates[i])
+            assert torch.allclose(experts[i].up_proj.weight, ups[i])
+            assert torch.allclose(experts[i].down_proj.weight, downs[i])