diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 523ba91b..6c8bc8e3 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -46,11 +46,16 @@ SUPPORTED_MODELS = { "Baichuan-7B": "baichuan-inc/Baichuan-7B", "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base", "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", + "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base", + "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base", + "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat", + "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat", "InternLM-7B": "internlm/internlm-7b", "InternLM-7B-Chat": "internlm/internlm-chat-7b", "Qwen-7B": "Qwen/Qwen-7B", "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", "XVERSE-13B": "xverse/XVERSE-13B", + "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat", "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b" } @@ -62,6 +67,7 @@ DEFAULT_MODULE = { "BLOOMZ": "query_key_value", "Falcon": "query_key_value", "Baichuan": "W_pack", + "Baichuan2": "W_pack", "InternLM": "q_proj,v_proj", "Qwen": "c_attn", "XVERSE": "q_proj,v_proj", @@ -72,7 +78,9 @@ DEFAULT_TEMPLATE = { "LLaMA2": "llama2", "ChineseLLaMA2": "llama2_zh", "Baichuan": "baichuan", + "Baichuan2": "baichuan", "InternLM": "intern", "Qwen": "chatml", + "XVERSE": "xverse", "ChatGLM2": "chatglm2" } diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index b57b1c8f..175cb6e9 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -1,3 +1,4 @@ +import gc import torch from typing import TYPE_CHECKING, List, Optional, Tuple from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList @@ -98,6 +99,7 @@ def torch_gc() -> None: r""" Collects GPU memory. """ + gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/extras/template.py index beb8e1f9..b4af406c 100644 --- a/src/llmtuner/extras/template.py +++ b/src/llmtuner/extras/template.py @@ -490,8 +490,7 @@ register_template( {"token": ""} # assistant token ], system="", - sep=[], - stop_words=[] + sep=[] )