mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-09 21:25:59 +08:00
[v1] Support meta loading for full and free (#10236)
This commit is contained in:
@@ -150,6 +150,9 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
|
|||||||
|
|
||||||
@PeftPlugin("lora").register()
|
@PeftPlugin("lora").register()
|
||||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||||
|
if model.device.type == "meta":
|
||||||
|
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||||
|
|
||||||
adapter_name_or_path = config.get("adapter_name_or_path")
|
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||||
|
|
||||||
if adapter_name_or_path:
|
if adapter_name_or_path:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
import gc
|
import gc
|
||||||
import os
|
import os
|
||||||
|
|
||||||
@@ -212,10 +213,52 @@ class FSDP2Engine:
|
|||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def _save_non_persistent_buffers(self, model: HFModel) -> dict:
|
||||||
|
"""Save non-persistent buffers, such as inv_freq."""
|
||||||
|
saved = {}
|
||||||
|
for mod_name, module in model.named_modules():
|
||||||
|
for buf_name in module._non_persistent_buffers_set:
|
||||||
|
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||||
|
buf = getattr(module, buf_name, None)
|
||||||
|
if buf is not None:
|
||||||
|
saved[fqn] = copy.deepcopy(buf)
|
||||||
|
if self.rank == 0 and saved:
|
||||||
|
logger.info(f"Saved {len(saved)} non-persistent buffers")
|
||||||
|
return saved
|
||||||
|
|
||||||
|
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
|
||||||
|
"""Register saved non-persistent buffers to model."""
|
||||||
|
if not saved_buffers:
|
||||||
|
return
|
||||||
|
device = get_current_accelerator()
|
||||||
|
for fqn, buf in saved_buffers.items():
|
||||||
|
buf = buf.to(device)
|
||||||
|
if "." in fqn:
|
||||||
|
parent_fqn, buf_name = fqn.rsplit(".", 1)
|
||||||
|
parent_module = model.get_submodule(parent_fqn)
|
||||||
|
else:
|
||||||
|
buf_name = fqn
|
||||||
|
parent_module = model
|
||||||
|
parent_module.register_buffer(buf_name, buf, persistent=False)
|
||||||
|
if self.rank == 0:
|
||||||
|
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
|
||||||
|
|
||||||
def shard_model(self, model: HFModel) -> HFModel:
|
def shard_model(self, model: HFModel) -> HFModel:
|
||||||
if model.device.type == "meta":
|
if model.device.type == "meta":
|
||||||
|
non_persistent_buffers = self._save_non_persistent_buffers(model)
|
||||||
|
|
||||||
|
if getattr(model.config, "tie_word_embeddings", None):
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
model = self.prepare_model(model)
|
model = self.prepare_model(model)
|
||||||
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
|
||||||
|
|
||||||
|
# fix tied broken for no-fsdp-wrap case
|
||||||
|
if getattr(model.config, "tie_word_embeddings", None):
|
||||||
|
model.tie_weights()
|
||||||
|
|
||||||
|
self._restore_non_persistent_buffers(model, non_persistent_buffers)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model = self.prepare_model(model)
|
model = self.prepare_model(model)
|
||||||
return model
|
return model
|
||||||
|
|||||||
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Unit tests: FSDP2 meta-device loading vs normal loading consistency.
|
||||||
|
|
||||||
|
Validates that the FSDP2 meta loading path behaves correctly for tied weights
|
||||||
|
and non-persistent buffers by comparing it with the standard non-meta path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||||
|
from llamafactory.v1.config.arg_parser import get_args
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
|
||||||
|
|
||||||
|
|
||||||
|
TINY_MODEL = "llamafactory/tiny-random-qwen3"
|
||||||
|
|
||||||
|
|
||||||
|
def collect_non_persistent_buffers(model):
|
||||||
|
"""Collect all non-persistent buffers from model."""
|
||||||
|
result = {}
|
||||||
|
for mod_name, module in model.named_modules():
|
||||||
|
for buf_name in getattr(module, "_non_persistent_buffers_set", set()):
|
||||||
|
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||||
|
buf = getattr(module, buf_name, None)
|
||||||
|
if buf is not None:
|
||||||
|
result[fqn] = buf.detach().cpu().clone()
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def test_fsdp2_meta_loading_buffers_and_tied_weights():
|
||||||
|
"""Verify non-persistent buffers and tied weights consistency after meta load."""
|
||||||
|
# 1. Initialize DistributedInterface for single process
|
||||||
|
DistributedInterface()
|
||||||
|
|
||||||
|
# 2. Build FSDP2Engine config
|
||||||
|
engine = FSDP2Engine(
|
||||||
|
{
|
||||||
|
"name": "fsdp2",
|
||||||
|
"mixed_precision": "bf16",
|
||||||
|
"reshard_after_forward": True,
|
||||||
|
"offload_params": False,
|
||||||
|
"pin_memory": False,
|
||||||
|
"dcp_path": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(TINY_MODEL)
|
||||||
|
|
||||||
|
# --- NORMAL PATH ---
|
||||||
|
normal_args, *_ = get_args(dict(model=TINY_MODEL, init_config=None))
|
||||||
|
normal_engine = ModelEngine(model_args=normal_args)
|
||||||
|
normal_model = normal_engine.model.to(torch.bfloat16)
|
||||||
|
|
||||||
|
normal_model = engine.shard_model(normal_model)
|
||||||
|
normal_non_persistent = collect_non_persistent_buffers(normal_model)
|
||||||
|
|
||||||
|
del normal_model
|
||||||
|
|
||||||
|
# --- META PATH ---
|
||||||
|
meta_args, *_ = get_args(dict(model=TINY_MODEL, init_config={"name": "init_on_meta"}))
|
||||||
|
meta_model_engine = ModelEngine(model_args=meta_args)
|
||||||
|
meta_model = meta_model_engine.model
|
||||||
|
|
||||||
|
assert meta_model.device.type == "meta", "Model should be on meta device"
|
||||||
|
|
||||||
|
# Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers
|
||||||
|
meta_model = engine.shard_model(meta_model)
|
||||||
|
meta_non_persistent = collect_non_persistent_buffers(meta_model)
|
||||||
|
|
||||||
|
# 3. Tied weights (embed_tokens.weight and lm_head.weight)
|
||||||
|
|
||||||
|
tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
|
||||||
|
if tie_word_embeddings:
|
||||||
|
assert meta_model.lm_head.weight is meta_model.model.embed_tokens.weight, (
|
||||||
|
"Weights should be tied after loading"
|
||||||
|
)
|
||||||
|
|
||||||
|
del meta_model
|
||||||
|
|
||||||
|
# 4. Non-persistent buffers (e.g., inv_freq)
|
||||||
|
normal_buf_keys = set(normal_non_persistent.keys())
|
||||||
|
meta_buf_keys = set(meta_non_persistent.keys())
|
||||||
|
assert normal_buf_keys == meta_buf_keys, "Non-persistent buffer keys mismatch"
|
||||||
|
|
||||||
|
for key in sorted(normal_buf_keys & meta_buf_keys):
|
||||||
|
nb = normal_non_persistent[key]
|
||||||
|
mb = meta_non_persistent[key]
|
||||||
|
assert nb.shape == mb.shape, f"Buffer shape mismatch: {key}"
|
||||||
|
assert torch.allclose(nb.float(), mb.float(), atol=1e-5), f"Buffer value mismatch: {key}"
|
||||||
Reference in New Issue
Block a user