Update llamafy_internlm2.py

Former-commit-id: 28135d787d62ddcec67fbec0aa78bbd2bdccad1e
This commit is contained in:
hiyouga 2024-01-18 01:12:31 +08:00
parent 6fa117a0ff
commit 71306bbfb1

View File

@ -40,7 +40,7 @@ def save_weight(
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict() llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"): for key, value in tqdm(internlm2_state_dict.items(), desc="Convert format"):
if "output" in key: if "output" in key:
llama2_state_dict["lm_head"] = value llama2_state_dict[key.replace("output", "lm_head")] = value
elif "tok_embeddings" in key: elif "tok_embeddings" in key:
llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value llama2_state_dict[key.replace("tok_embeddings", "embed_tokens")] = value
elif "attention_norm" in key: elif "attention_norm" in key: