Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-03-16 09:05:59 +08:00.
[v1] Support meta loading for full and free (#10236)
This commit is contained in:
@@ -150,6 +150,9 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
|
||||
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||
|
||||
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||
|
||||
if adapter_name_or_path:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
|
||||
@@ -212,10 +213,52 @@ class FSDP2Engine:
|
||||
|
||||
return model
|
||||
|
||||
def _save_non_persistent_buffers(self, model: HFModel) -> dict:
|
||||
"""Save non-persistent buffers, such as inv_freq."""
|
||||
saved = {}
|
||||
for mod_name, module in model.named_modules():
|
||||
for buf_name in module._non_persistent_buffers_set:
|
||||
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||
buf = getattr(module, buf_name, None)
|
||||
if buf is not None:
|
||||
saved[fqn] = copy.deepcopy(buf)
|
||||
if self.rank == 0 and saved:
|
||||
logger.info(f"Saved {len(saved)} non-persistent buffers")
|
||||
return saved
|
||||
|
||||
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
|
||||
"""Register saved non-persistent buffers to model."""
|
||||
if not saved_buffers:
|
||||
return
|
||||
device = get_current_accelerator()
|
||||
for fqn, buf in saved_buffers.items():
|
||||
buf = buf.to(device)
|
||||
if "." in fqn:
|
||||
parent_fqn, buf_name = fqn.rsplit(".", 1)
|
||||
parent_module = model.get_submodule(parent_fqn)
|
||||
else:
|
||||
buf_name = fqn
|
||||
parent_module = model
|
||||
parent_module.register_buffer(buf_name, buf, persistent=False)
|
||||
if self.rank == 0:
|
||||
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
|
||||
|
||||
def shard_model(self, model: HFModel) -> HFModel:
    """Wrap ``model`` for FSDP2, handling meta-device initialization.

    For a meta-device model, real weights are materialized after wrapping
    (from the HF checkpoint path or ``self.dcp_path``); non-persistent buffers
    are snapshotted beforehand and re-attached at the end, and tied embeddings
    are re-tied around the wrap since sharding can break the tie when no FSDP
    wrap covers both weights.

    Args:
        model: the model to shard; may live on the meta device.

    Returns:
        The sharded (and, for meta input, materialized) model.
    """
    if model.device.type != "meta":
        # Ordinary path: weights are already materialized, just wrap.
        return self.prepare_model(model)

    # Meta path: snapshot buffers that materialization would otherwise lose.
    buffer_snapshot = self._save_non_persistent_buffers(model)

    if getattr(model.config, "tie_word_embeddings", None):
        model.tie_weights()

    model = self.prepare_model(model)
    model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)

    # fix tied broken for no-fsdp-wrap case
    if getattr(model.config, "tie_word_embeddings", None):
        model.tie_weights()

    self._restore_non_persistent_buffers(model, buffer_snapshot)
    return model
|
||||
|
||||
Reference in New Issue
Block a user