diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 55237d88..fa0abb93 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -128,7 +128,7 @@ def load_model_and_tokenizer(
         else:
             logger.warning("Current model does not support RoPE scaling.")
 
-    # Set FlashAttention-2
+    # Set FlashAttention-2 and S^2-Attn
     if model_args.flash_attn:
         if getattr(config, "model_type", None) == "llama":
             LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
diff --git a/tests/llamafy_baichuan2.py b/tests/llamafy_baichuan2.py
index f7b07428..d08eee1c 100644
--- a/tests/llamafy_baichuan2.py
+++ b/tests/llamafy_baichuan2.py
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
-# Usage: python llamafy_baichuan2.py --llama2_json llama2.index.json --input_dir input --output_dir output
+# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
 # Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
 # Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
 
@@ -9,56 +9,77 @@ import fire
 import json
 import torch
 from collections import OrderedDict
+from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
+from typing import Any, Dict
 
-SHARD_A = "pytorch_model-00001-of-00002.bin"
-SHARD_B = "pytorch_model-00002-of-00002.bin"
+
+CONFIG_NAME = "config.json"
 
 
-def llamafy_baichuan2(
-    llama2_json: str,
+def save_weight(
     input_dir: str,
-    output_dir: str
+    output_dir: str,
+    shard_size: str
 ):
-    baichuan2_state_dict = OrderedDict()
+    baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
     for filepath in os.listdir(input_dir):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
             shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
             baichuan2_state_dict.update(shard_weight)
 
-    llama2_state_dict = OrderedDict()
-    total_size = 0
+    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
     for key, value in baichuan2_state_dict.items():
-        total_size += 2 * value.numel() # half precision
         if "W_pack" in key:
-            llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:4096, :]
-            llama2_state_dict[key.replace("W_pack", "k_proj")] = value[4096:2*4096, :]
-            llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*4096:, :]
+            proj_size = value.size(0) // 3
+            llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
+            llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
+            llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
         elif "lm_head" in key:
             llama2_state_dict[key] = torch.nn.functional.normalize(value)
         else:
             llama2_state_dict[key] = value
 
-    with open(os.path.join(input_dir, llama2_json), "r", encoding="utf-8") as f:
-        llama2_index = json.load(f)
+    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
+    for shard_file, shard in shards.items():
+        torch.save(shard, os.path.join(output_dir, shard_file))
+
+    if index is None:
+        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
+    else:
+        with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
+            json.dump(index, f, indent=2, sort_keys=True)
+        print("Model weights saved in {}".format(output_dir))
 
-    merged_index = OrderedDict()
-    merged_index["metadata"] = {"total_size": total_size}
merged_index["weight_map"] = llama2_index["weight_map"] - state_dict_a, state_dict_b = OrderedDict(), OrderedDict() - for key, value in llama2_state_dict.items(): - if merged_index["weight_map"][key] == SHARD_A: - state_dict_a[key] = value - else: - state_dict_b[key] = value +def save_config( + input_dir: str, + output_dir: str +): + with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: + llama2_config_dict: Dict[str, Any] = json.load(f) - os.makedirs(output_dir, exist_ok=True) - torch.save(state_dict_a, os.path.join(output_dir, SHARD_A)) - torch.save(state_dict_b, os.path.join(output_dir, SHARD_B)) - with open(os.path.join(output_dir, "pytorch_model.bin.index.json"), "w", encoding="utf-8") as f: - json.dump(merged_index, f, indent=2) - print("Completed!") + llama2_config_dict["architectures"] = ["LlamaForCausalLM"] + llama2_config_dict.pop("auto_map", None) + llama2_config_dict.pop("tokenizer_class", None) + llama2_config_dict["model_type"] = "llama" + + with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: + json.dump(llama2_config_dict, f, indent=2) + print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) + + +def llamafy_baichuan2( + input_dir: str, + output_dir: str, + shard_size: str +): + try: + os.makedirs(output_dir, exist_ok=False) + except Exception as e: + raise print("Output dir already exists", e) + + save_weight(input_dir, output_dir, shard_size) + save_config(input_dir, output_dir) if __name__ == "__main__": diff --git a/tests/llamafy_qwen.py b/tests/llamafy_qwen.py new file mode 100644 index 00000000..8b9fc395 --- /dev/null +++ b/tests/llamafy_qwen.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Converts the Qwen models in the same format as LLaMA2. +# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB + +import os +import fire +import json +import torch +from collections import OrderedDict +from safetensors import safe_open +from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME +from transformers.utils import check_min_version +from typing import Any, Dict + +try: + check_min_version("4.34.0") +except: + raise ValueError("Please upgrade `transformers` to 4.34.0") + + +CONFIG_NAME = "config.json" + + +def save_weight( + input_dir: str, + output_dir: str, + shard_size: str +) -> str: + qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict() + for filepath in os.listdir(input_dir): + if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"): + with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f: + for key in f.keys(): + qwen_state_dict[key] = f.get_tensor(key) + + llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict() + torch_dtype = None + for key, value in qwen_state_dict.items(): + if torch_dtype is None: + torch_dtype = value.dtype + if "wte" in key: + llama2_state_dict["model.embed_tokens.weight"] = value + elif "ln_f" in key: + llama2_state_dict["model.norm.weight"] = value + else: + key = key.replace("transformer.h", "model.layers") + if "attn.c_attn" in key: + proj_size = value.size(0) // 3 + llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...] + llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...] + llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...] 
+ elif "attn.c_proj" in key: + llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value + llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = ( + torch.zeros_like(value[:, 0]).squeeze() + ) + elif "ln_1" in key: + llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value + elif "ln_2" in key: + llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value + elif "mlp.w1" in key: + llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value + elif "mlp.w2" in key: + llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value + elif "mlp.c_proj" in key: + llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value + elif "lm_head" in key: + llama2_state_dict[key] = value + else: + raise KeyError("Unable to process key {}".format(key)) + + shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME) + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(output_dir, shard_file)) + + if index is None: + print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME))) + else: + with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f: + json.dump(index, f, indent=2, sort_keys=True) + print("Model weights saved in {}".format(output_dir)) + + return str(torch_dtype).replace("torch.", "") + + +def save_config( + input_dir: str, + output_dir: str, + torch_dtype: str +): + with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f: + qwen_config_dict: Dict[str, Any] = json.load(f) + + llama2_config_dict: Dict[str, Any] = OrderedDict() + llama2_config_dict["architectures"] = ["LlamaForCausalLM"] + llama2_config_dict["hidden_act"] = "silu" + llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"] + llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"] + llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2 + llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"] + llama2_config_dict["model_type"] = "llama" + llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"] + llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"] + llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"] + llama2_config_dict["pretraining_tp"] = 1 + llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"] + llama2_config_dict["rope_scaling"] = None + llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"] + llama2_config_dict["torch_dtype"] = torch_dtype + llama2_config_dict["transformers_version"] = "4.34.0" + llama2_config_dict["use_cache"] = True + llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"] + llama2_config_dict["attention_bias"] = True + + with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f: + json.dump(llama2_config_dict, f, indent=2) + print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME))) + + +def llamafy_qwen( + input_dir: str, + output_dir: str, + shard_size: str +): + try: + os.makedirs(output_dir, exist_ok=False) + except Exception as e: + raise print("Output dir already exists", e) + + torch_dtype = save_weight(input_dir, output_dir, shard_size) + save_config(input_dir, output_dir, torch_dtype) + + +if __name__ == "__main__": + fire.Fire(llamafy_qwen)