[v1] fix init on meta in transformers v5 (#10414)

This commit is contained in:
jiaqiw09
2026-04-27 00:37:09 +08:00
committed by GitHub
parent c8890c32db
commit 9a0cfdccfa
2 changed files with 255 additions and 16 deletions

View File

@@ -17,7 +17,6 @@ import gc
import os
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from peft.tuners.lora import LoraLayer
@@ -43,6 +42,37 @@ from ....utils.types import HFModel, Processor
logger = get_logger(__name__)
def _fallback_dot_natural_key(name: str):
parts = []
for part in name.split("."):
if part.isdigit():
parts.append((0, int(part)))
else:
parts.append((1, part))
return parts
def _get_checkpoint_sort_key():
try:
from transformers.core_model_loading import dot_natural_key
return dot_natural_key
except ImportError:
return _fallback_dot_natural_key
def _make_safetensor_loader(checkpoint_file: str, tensor_key: str):
# Delay tensor materialization until converter.convert() to reduce peak CPU memory.
# This works because HF WeightConverter accepts callables and materializes them later.
def _load_tensor():
from safetensors import safe_open
with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
return f.get_tensor(tensor_key)
return _load_tensor
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
no_split_modules = getattr(model, "_no_split_modules", None)
if no_split_modules:
@@ -129,7 +159,7 @@ class FSDP2Engine:
self.offload_params = dist_config.get("offload_params", False)
self.pin_memory = dist_config.get("pin_memory", True)
self.dcp_path = dist_config.get("dcp_path", None)
self.device_mesh = self.dist_interface.data_device_mesh
self.device_mesh = self.dist_interface.model_device_mesh
if self.device_mesh is None:
logger.warning(
@@ -303,9 +333,6 @@ class FSDP2Engine:
else:
full_sd = {}
# Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
model = self.prepare_model(model)
device = get_current_accelerator()
@@ -316,11 +343,6 @@ class FSDP2Engine:
options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
set_model_state_dict(model, full_sd, options=options)
# Broadcast and restore non-persistent buffers
buffers_to_sync = [saved_buffers]
dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
self._restore_non_persistent_buffers(model, buffers_to_sync[0])
if self.rank == 0:
logger.info("init_on_rank0 sync complete.")
@@ -387,11 +409,37 @@ class FSDP2Engine:
logger.error(f"Failed to load from DCP: {e}")
raise e
def _try_build_hf_weight_conversion_context(self, model: HFModel) -> dict | None:
try:
from transformers.conversion_mapping import get_model_conversion_mapping
from transformers.core_model_loading import WeightConverter, WeightRenaming, rename_source_key
except ImportError:
return None
weight_mapping = get_model_conversion_mapping(model)
if not weight_mapping:
return None
renamings = [entry for entry in weight_mapping if isinstance(entry, WeightRenaming)]
converters = [entry for entry in weight_mapping if isinstance(entry, WeightConverter)]
return {
"prefix": getattr(model, "base_model_prefix", ""),
"meta_state_dict": model.state_dict(),
"rename_source_key": rename_source_key,
"renamings": renamings,
"converters": converters,
"converter_templates": {
pattern: converter for converter in converters for pattern in converter.source_patterns
},
"pending_converters": {},
}
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
import glob
import json
hf_model_path = self._resolve_hf_checkpoint_dir(hf_model_path)
sort_key = _get_checkpoint_sort_key()
if self.rank == 0:
logger.info(f"Loading weights from {hf_model_path} ...")
@@ -428,6 +476,7 @@ class FSDP2Engine:
raise ValueError(f"No checkpoint files found in {hf_model_path}")
param_map = dict(model.named_parameters())
conversion_ctx = self._try_build_hf_weight_conversion_context(model)
total_files = len(checkpoint_files)
for i, ckpt_file in enumerate(checkpoint_files):
@@ -438,18 +487,71 @@ class FSDP2Engine:
from safetensors import safe_open
with safe_open(ckpt_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key in param_map:
for key in sorted(f.keys(), key=sort_key):
renamed_key = key
source_pattern = None
if conversion_ctx is not None:
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
key,
conversion_ctx["renamings"],
conversion_ctx["converters"],
prefix=conversion_ctx["prefix"],
meta_state_dict=conversion_ctx["meta_state_dict"],
)
if source_pattern is not None:
template = conversion_ctx["converter_templates"][source_pattern]
converter = conversion_ctx["pending_converters"].setdefault(
renamed_key, copy.deepcopy(template)
)
converter.add_tensor(
renamed_key,
key,
source_pattern,
_make_safetensor_loader(ckpt_file, key),
)
elif renamed_key in param_map:
tensor = f.get_tensor(key)
self._copy_weights(param_map[key], tensor)
self._copy_weights(param_map[renamed_key], tensor)
else:
state_dict = torch.load(ckpt_file, map_location="cpu")
for key, tensor in state_dict.items():
if key in param_map:
self._copy_weights(param_map[key], tensor)
for key, tensor in sorted(state_dict.items(), key=lambda item: sort_key(item[0])):
renamed_key = key
source_pattern = None
if conversion_ctx is not None:
renamed_key, source_pattern = conversion_ctx["rename_source_key"](
key,
conversion_ctx["renamings"],
conversion_ctx["converters"],
prefix=conversion_ctx["prefix"],
meta_state_dict=conversion_ctx["meta_state_dict"],
)
if source_pattern is not None:
template = conversion_ctx["converter_templates"][source_pattern]
converter = conversion_ctx["pending_converters"].setdefault(
renamed_key, copy.deepcopy(template)
)
converter.add_tensor(renamed_key, key, source_pattern, tensor)
elif renamed_key in param_map:
self._copy_weights(param_map[renamed_key], tensor)
del state_dict
gc.collect()
if conversion_ctx is not None:
pending_count = len(conversion_ctx["pending_converters"])
log_fn = getattr(logger, "info_rank0", logger.info)
log_fn(f"Applying {pending_count} deferred HF weight conversions.")
for layer_name, converter in sorted(conversion_ctx["pending_converters"].items()):
realized_tensors = converter.convert(layer_name, model=model, config=model.config)
for target_name, tensor in realized_tensors.items():
if isinstance(tensor, list):
tensor = tensor[0]
if target_name in param_map:
self._copy_weights(param_map[target_name], tensor)
del realized_tensors
gc.collect()
def _resolve_hf_checkpoint_dir(self, hf_model_path: str) -> str:
"""Resolve a HF model identifier or local path to a local directory containing checkpoint files.