fix llamafy_internlm2

This commit is contained in:
hiyouga
2024-01-18 00:26:14 +08:00
parent f1d7ca77b1
commit 7ff4c874d2

View File

@@ -65,7 +65,7 @@ def save_weight(
elif "w3" in key: elif "w3" in key:
llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value llama2_state_dict[key.replace("feed_forward.w3", "mlp.up_proj")] = value
else: else:
raise KeyError("Unable to process key {}".format(key)) llama2_state_dict[key] = value
weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name) shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)