mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
[misc] update license year & fix llama pro (#6814)
* fix llamapro script * change year
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||
# Copyright 2025 Tencent Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||
@@ -18,20 +18,15 @@
|
||||
import json
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from huggingface_hub import split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
|
||||
from transformers.modeling_utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
shard_checkpoint,
|
||||
)
|
||||
from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -46,41 +41,36 @@ def block_expansion(
|
||||
model_name_or_path: str,
|
||||
output_dir: str,
|
||||
num_expand: int,
|
||||
shard_size: str = "2GB",
|
||||
shard_size: str = "5GB",
|
||||
save_safetensors: bool = True,
|
||||
):
|
||||
r"""
|
||||
Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
|
||||
Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
|
||||
Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
|
||||
"""
|
||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
|
||||
config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
num_layers = getattr(config, "num_hidden_layers")
|
||||
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||
config.save_pretrained(output_dir)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name_or_path) # load the original one
|
||||
if save_safetensors:
|
||||
setattr(config, "tie_word_embeddings", False) # safetensors does not allow shared weights
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path,
|
||||
config=config,
|
||||
torch_dtype="auto",
|
||||
trust_remote_code=True,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
assert isinstance(model, PreTrainedModel) # type hint
|
||||
state_dict = model.state_dict()
|
||||
|
||||
if num_layers % num_expand != 0:
|
||||
raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")
|
||||
|
||||
setattr(config, "num_hidden_layers", num_layers + num_expand)
|
||||
config.save_pretrained(output_dir)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
|
||||
print(f"Expanding model of {num_layers} layers to {num_layers + num_expand} layers.")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name_or_path, torch_dtype="auto", device_map="cpu", trust_remote_code=True, low_cpu_mem_usage=True
|
||||
)
|
||||
assert isinstance(model, PreTrainedModel) # type hint
|
||||
if save_safetensors and getattr(model.config, "tie_word_embeddings", False):
|
||||
del model.lm_head # safetensors does not allow shared weights
|
||||
|
||||
split = num_layers // num_expand
|
||||
layer_cnt = 0
|
||||
output_state_dict = OrderedDict()
|
||||
state_dict = model.state_dict()
|
||||
output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
|
||||
for i in range(num_layers):
|
||||
for key, value in state_dict.items():
|
||||
if f".{i:d}." in key:
|
||||
@@ -104,17 +94,24 @@ def block_expansion(
|
||||
output_state_dict[key] = value
|
||||
|
||||
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
|
||||
shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)
|
||||
|
||||
for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
|
||||
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
output_state_dict, filename_pattern=filename_pattern, max_shard_size=shard_size
|
||||
)
|
||||
for shard_file, tensors in tqdm(state_dict_split.filename_to_tensors.items(), desc="Save weights"):
|
||||
shard = {tensor: output_state_dict[tensor].contiguous() for tensor in tensors}
|
||||
if save_safetensors:
|
||||
save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
if not state_dict_split.is_sharded:
|
||||
print(f"Model weights saved in {os.path.join(output_dir, weights_name)}.")
|
||||
else:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
|
||||
with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
|
||||
Reference in New Issue
Block a user