From aad79127e6759c07f45c360deefae2f7c9a20cc5 Mon Sep 17 00:00:00 2001 From: huniu20 Date: Thu, 10 Oct 2024 16:46:34 +0800 Subject: [PATCH] 1. add model and dataset info to support webui Former-commit-id: 92f6226f3fecbd9af744a7232dda2c68b2bb0d86 --- README.md | 2 +- README_zh.md | 3 +-- src/llamafactory/data/loader.py | 2 +- src/llamafactory/data/parser.py | 1 + src/llamafactory/extras/constants.py | 12 ++++++++++++ src/llamafactory/webui/common.py | 9 ++++++++- 6 files changed, 24 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 40253b16..621a1c21 100644 --- a/README.md +++ b/README.md @@ -417,7 +417,7 @@ Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaij ### Data Preparation -Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can use datasets on HuggingFace hub, ModelScope hub, modelers hub or load the dataset in local disk. +Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use datasets on HuggingFace / ModelScope / Modelers hub or load the dataset in local disk. > [!NOTE] > Please update `data/dataset_info.json` to use your custom dataset. 
diff --git a/README_zh.md b/README_zh.md index 61aa860c..1d18ad16 100644 --- a/README_zh.md +++ b/README_zh.md @@ -417,8 +417,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh ### 数据准备 -关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace , ModelScope 或者 Modelers 上的数据集或加载本地数据集。 - +关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。你可以使用 HuggingFace / ModelScope / Modelers 上的数据集或加载本地数据集。 > [!NOTE] > 使用自定义数据集时,请更新 `data/dataset_info.json` 文件。 diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 03a38fa8..0849b603 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -94,7 +94,7 @@ def _load_single_dataset( name=data_name, data_dir=data_dir, data_files=data_files, - split=data_args.split, + split=dataset_attr.split, cache_dir=cache_dir, token=model_args.ms_hub_token, streaming=(data_args.streaming and (dataset_attr.load_from != "file")), diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 696cd488..879264bb 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -98,6 +98,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) - dataset_list: List["DatasetAttr"] = [] for name in dataset_names: if dataset_info is None: # dataset_dir is ONLINE + load_from = None if use_openmind(): load_from = "om_hub" if use_modelscope(): diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index 7ed20c2b..bf2386f4 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -107,6 +107,7 @@ VISION_MODELS = set() class DownloadSource(str, Enum): DEFAULT = "hf" MODELSCOPE = "ms" + MODELERS = "om" def register_model_group( @@ -163,14 +164,17 @@ register_model_group( "Baichuan2-13B-Base": { DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base", + 
DownloadSource.MODELERS: "Baichuan/Baichuan2_13b_base_pt" }, "Baichuan2-7B-Chat": { DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat", + DownloadSource.MODELERS: "Baichuan/Baichuan2_7b_chat_pt" }, "Baichuan2-13B-Chat": { DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", + DownloadSource.MODELERS: "Baichuan/Baichuan2_13b_chat_pt" }, }, template="baichuan2", @@ -559,6 +563,7 @@ register_model_group( "Gemma-2-9B-Instruct": { DownloadSource.DEFAULT: "google/gemma-2-9b-it", DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", + DownloadSource.MODELERS: "LlamaFactory/gemma-2-9b-it" }, "Gemma-2-27B-Instruct": { DownloadSource.DEFAULT: "google/gemma-2-27b-it", @@ -656,6 +661,7 @@ register_model_group( "InternLM2.5-20B-Chat": { DownloadSource.DEFAULT: "internlm/internlm2_5-20b-chat", DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2_5-20b-chat", + DownloadSource.MODELERS: "Intern/internlm2_5-20b-chat" }, }, template="intern2", @@ -756,6 +762,7 @@ register_model_group( "Llama-3-8B-Chinese-Chat": { DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat", DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat", + DownloadSource.MODELERS: "HaM/Llama3-8B-Chinese-Chat", }, "Llama-3-70B-Chinese-Chat": { DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat", @@ -960,6 +967,7 @@ register_model_group( "MiniCPM3-4B-Chat": { DownloadSource.DEFAULT: "openbmb/MiniCPM3-4B", DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM3-4B", + DownloadSource.MODELERS: "LlamaFactory/MiniCPM3-4B" }, }, template="cpm3", @@ -1699,6 +1707,7 @@ register_model_group( "Qwen2-VL-2B-Instruct": { DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct", DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct", + DownloadSource.MODELERS: "LlamaFactory/Qwen2-VL-2B-Instruct" }, "Qwen2-VL-7B-Instruct": { DownloadSource.DEFAULT: 
"Qwen/Qwen2-VL-7B-Instruct", @@ -1801,10 +1810,12 @@ register_model_group( "TeleChat-7B-Chat": { DownloadSource.DEFAULT: "Tele-AI/telechat-7B", DownloadSource.MODELSCOPE: "TeleAI/telechat-7B", + DownloadSource.MODELERS: "TeleAI/TeleChat-7B-pt" }, "TeleChat-12B-Chat": { DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B", DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B", + DownloadSource.MODELERS: "TeleAI/TeleChat-12B-pt", }, "TeleChat-12B-v2-Chat": { DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2", @@ -2023,6 +2034,7 @@ register_model_group( "Yi-1.5-6B-Chat": { DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat", DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat", + DownloadSource.MODELERS: "LlamaFactory/Yi-1.5-6B-Chat" }, "Yi-1.5-9B-Chat": { DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat", diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index d4e9be51..a078c976 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -31,7 +31,7 @@ from ..extras.constants import ( DownloadSource, ) from ..extras.logging import get_logger -from ..extras.misc import use_modelscope +from ..extras.misc import use_modelscope, use_openmind from ..extras.packages import is_gradio_available @@ -112,6 +112,13 @@ def get_model_path(model_name: str) -> str: ): # replace path model_path = path_dict.get(DownloadSource.MODELSCOPE) + if ( + use_openmind() + and path_dict.get(DownloadSource.MODELERS) + and model_path == path_dict.get(DownloadSource.DEFAULT) + ): # replace path + model_path = path_dict.get(DownloadSource.MODELERS) + return model_path