add Baichuan2 models

Former-commit-id: 62ce65c6282d2bbcb765354acc2819cc3e983a46
This commit is contained in:
hiyouga 2023-09-06 18:36:04 +08:00
parent 5b4f59c3f9
commit f9aee17f9d
3 changed files with 11 additions and 2 deletions

View File

@@ -46,11 +46,16 @@ SUPPORTED_MODELS = {
     "Baichuan-7B": "baichuan-inc/Baichuan-7B",
     "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
     "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
+    "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
+    "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
+    "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
+    "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
     "InternLM-7B": "internlm/internlm-7b",
     "InternLM-7B-Chat": "internlm/internlm-chat-7b",
     "Qwen-7B": "Qwen/Qwen-7B",
     "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
     "XVERSE-13B": "xverse/XVERSE-13B",
+    "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
     "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
 }
@@ -62,6 +67,7 @@ DEFAULT_MODULE = {
     "BLOOMZ": "query_key_value",
     "Falcon": "query_key_value",
     "Baichuan": "W_pack",
+    "Baichuan2": "W_pack",
     "InternLM": "q_proj,v_proj",
     "Qwen": "c_attn",
     "XVERSE": "q_proj,v_proj",
@@ -72,7 +78,9 @@ DEFAULT_TEMPLATE = {
     "LLaMA2": "llama2",
     "ChineseLLaMA2": "llama2_zh",
     "Baichuan": "baichuan",
+    "Baichuan2": "baichuan",
     "InternLM": "intern",
     "Qwen": "chatml",
+    "XVERSE": "xverse",
     "ChatGLM2": "chatglm2"
 }

View File

@@ -1,3 +1,4 @@
+import gc
 import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
@@ -98,6 +99,7 @@ def torch_gc() -> None:
     r"""
     Collects GPU memory.
     """
+    gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()

View File

@@ -490,8 +490,7 @@ register_template(
         {"token": "<reserved_103>"}  # assistant token
     ],
     system="",
-    sep=[],
-    stop_words=[]
+    sep=[]
 )