From 636d8a886c1b6279298449d8cdc969b9862fcfae Mon Sep 17 00:00:00 2001 From: hiyouga Date: Thu, 18 Jan 2024 00:49:31 +0800 Subject: [PATCH] Update llamafy_internlm2.py Former-commit-id: 1f1a7bcee5a5bb0fa17b13aa6393bfba89451dd7 --- tests/llamafy_internlm2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/llamafy_internlm2.py b/tests/llamafy_internlm2.py index 996aef4b..e6ca8058 100644 --- a/tests/llamafy_internlm2.py +++ b/tests/llamafy_internlm2.py @@ -49,8 +49,8 @@ def save_weight( proj_size = value.size(0) num_q_heads = internlm2_config_dict["num_attention_heads"] num_kv_heads = internlm2_config_dict["num_key_value_heads"] - q_size = proj_size // (num_q_heads + num_kv_heads) * num_q_heads - kv_size = proj_size // (num_q_heads + num_kv_heads) * num_kv_heads + q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads + kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...] llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]