Update llamafy_internlm2.py

Former-commit-id: 1f1a7bcee5a5bb0fa17b13aa6393bfba89451dd7
This commit is contained in:
hiyouga 2024-01-18 00:49:31 +08:00
parent 97b52c7fdf
commit 636d8a886c

View File

@ -49,8 +49,8 @@ def save_weight(
proj_size = value.size(0)
num_q_heads = internlm2_config_dict["num_attention_heads"]
num_kv_heads = internlm2_config_dict["num_key_value_heads"]
q_size = proj_size // (num_q_heads + num_kv_heads) * num_q_heads
kv_size = proj_size // (num_q_heads + num_kv_heads) * num_kv_heads
q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads
kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]