From 0ee1c42c2be5b798609b5649b4a97ec1fc78e45c Mon Sep 17 00:00:00 2001 From: jiaqiw09 <60021713+jiaqiw09@users.noreply.github.com> Date: Thu, 5 Mar 2026 23:15:27 +0800 Subject: [PATCH] [v1] Support meta loading for full and free (#10236) --- .../v1/plugins/model_plugins/peft.py | 3 + .../trainer_plugins/distributed/fsdp2.py | 43 ++++++++ .../trainer_plugins/distributed/test_fsdp2.py | 104 ++++++++++++++++++ 3 files changed, 150 insertions(+) create mode 100644 tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py diff --git a/src/llamafactory/v1/plugins/model_plugins/peft.py b/src/llamafactory/v1/plugins/model_plugins/peft.py index 2ef2035e1..17ff3779e 100644 --- a/src/llamafactory/v1/plugins/model_plugins/peft.py +++ b/src/llamafactory/v1/plugins/model_plugins/peft.py @@ -150,6 +150,9 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is @PeftPlugin("lora").register() def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel: + if model.device.type == "meta": + raise ValueError("Currently lora stage does not support loading model by meta.") + adapter_name_or_path = config.get("adapter_name_or_path") if adapter_name_or_path: diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py index f32607627..7d4fac3cc 100644 --- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py +++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/fsdp2.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import gc import os @@ -212,10 +213,52 @@ class FSDP2Engine: return model + def _save_non_persistent_buffers(self, model: HFModel) -> dict: + """Save non-persistent buffers, such as inv_freq.""" + saved = {} + for mod_name, module in model.named_modules(): + for buf_name in module._non_persistent_buffers_set: + fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name + buf = getattr(module, buf_name, None) + if buf is not None: + saved[fqn] = copy.deepcopy(buf) + if self.rank == 0 and saved: + logger.info(f"Saved {len(saved)} non-persistent buffers") + return saved + + def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict): + """Register saved non-persistent buffers to model.""" + if not saved_buffers: + return + device = get_current_accelerator() + for fqn, buf in saved_buffers.items(): + buf = buf.to(device) + if "." in fqn: + parent_fqn, buf_name = fqn.rsplit(".", 1) + parent_module = model.get_submodule(parent_fqn) + else: + buf_name = fqn + parent_module = model + parent_module.register_buffer(buf_name, buf, persistent=False) + if self.rank == 0: + logger.info(f"Restored {len(saved_buffers)} non-persistent buffers") + def shard_model(self, model: HFModel) -> HFModel: if model.device.type == "meta": + non_persistent_buffers = self._save_non_persistent_buffers(model) + + if getattr(model.config, "tie_word_embeddings", None): + model.tie_weights() + model = self.prepare_model(model) model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path) + + # fix tied broken for no-fsdp-wrap case + if getattr(model.config, "tie_word_embeddings", None): + model.tie_weights() + + self._restore_non_persistent_buffers(model, non_persistent_buffers) + else: model = self.prepare_model(model) return model diff --git a/tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py b/tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py new file mode 100644 index 000000000..d1b0d7bf8 --- /dev/null +++ b/tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py @@ -0,0 +1,104 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests: FSDP2 meta-device loading vs normal loading consistency. + +Validates that the FSDP2 meta loading path behaves correctly for tied weights +and non-persistent buffers by comparing it with the standard non-meta path. +""" + +import torch +from transformers import AutoConfig + +from llamafactory.v1.accelerator.interface import DistributedInterface +from llamafactory.v1.config.arg_parser import get_args +from llamafactory.v1.core.model_engine import ModelEngine +from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine + + +TINY_MODEL = "llamafactory/tiny-random-qwen3" + + +def collect_non_persistent_buffers(model): + """Collect all non-persistent buffers from model.""" + result = {} + for mod_name, module in model.named_modules(): + for buf_name in getattr(module, "_non_persistent_buffers_set", set()): + fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name + buf = getattr(module, buf_name, None) + if buf is not None: + result[fqn] = buf.detach().cpu().clone() + return result + + +def test_fsdp2_meta_loading_buffers_and_tied_weights(): + """Verify non-persistent buffers and tied weights consistency after meta load.""" + # 1. Initialize DistributedInterface for single process + DistributedInterface() + + # 2. Build FSDP2Engine config + engine = FSDP2Engine( + { + "name": "fsdp2", + "mixed_precision": "bf16", + "reshard_after_forward": True, + "offload_params": False, + "pin_memory": False, + "dcp_path": None, + } + ) + + config = AutoConfig.from_pretrained(TINY_MODEL) + + # --- NORMAL PATH --- + normal_args, *_ = get_args(dict(model=TINY_MODEL, init_config=None)) + normal_engine = ModelEngine(model_args=normal_args) + normal_model = normal_engine.model.to(torch.bfloat16) + + normal_model = engine.shard_model(normal_model) + normal_non_persistent = collect_non_persistent_buffers(normal_model) + + del normal_model + + # --- META PATH --- + meta_args, *_ = get_args(dict(model=TINY_MODEL, init_config={"name": "init_on_meta"})) + meta_model_engine = ModelEngine(model_args=meta_args) + meta_model = meta_model_engine.model + + assert meta_model.device.type == "meta", "Model should be on meta device" + + # Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers + meta_model = engine.shard_model(meta_model) + meta_non_persistent = collect_non_persistent_buffers(meta_model) + + # 3. Tied weights (embed_tokens.weight and lm_head.weight) + + tie_word_embeddings = getattr(config, "tie_word_embeddings", False) + if tie_word_embeddings: + assert meta_model.lm_head.weight is meta_model.model.embed_tokens.weight, ( + "Weights should be tied after loading" + ) + + del meta_model + + # 4. Non-persistent buffers (e.g., inv_freq) + normal_buf_keys = set(normal_non_persistent.keys()) + meta_buf_keys = set(meta_non_persistent.keys()) + assert normal_buf_keys == meta_buf_keys, "Non-persistent buffer keys mismatch" + + for key in sorted(normal_buf_keys & meta_buf_keys): + nb = normal_non_persistent[key] + mb = meta_non_persistent[key] + assert nb.shape == mb.shape, f"Buffer shape mismatch: {key}" + assert torch.allclose(nb.float(), mb.float(), atol=1e-5), f"Buffer value mismatch: {key}"