mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
Update llamafy_internlm2.py
Former-commit-id: c84a387c2c0e6ac5a33fd233bcfba4363eccb4cb
This commit is contained in:
parent
85a2e81dc2
commit
925b61ff99
@@ -49,8 +49,8 @@ def save_weight(
         proj_size = value.size(0)
         num_q_heads = internlm2_config_dict["num_attention_heads"]
         num_kv_heads = internlm2_config_dict["num_key_value_heads"]
-        q_size = proj_size // (num_q_heads + num_kv_heads) * num_q_heads
-        kv_size = proj_size // (num_q_heads + num_kv_heads) * num_kv_heads
+        q_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_q_heads
+        kv_size = proj_size // (num_q_heads + 2 * num_kv_heads) * num_kv_heads
         llama2_state_dict[key.replace("attention.wqkv", "self_attn.q_proj")] = value[:q_size, ...]
         llama2_state_dict[key.replace("attention.wqkv", "self_attn.k_proj")] = value[q_size:q_size+kv_size, ...]
         llama2_state_dict[key.replace("attention.wqkv", "self_attn.v_proj")] = value[q_size+kv_size:, ...]
Loading…
x
Reference in New Issue
Block a user