From eb835b693d22c154d986ba5ff31e5f878e2c00e8 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 1 Dec 2023 17:27:00 +0800 Subject: [PATCH] fix bug Former-commit-id: d9e52957e272e8133f1b37cf20d193084425e09e --- src/llmtuner/extras/constants.py | 4 +++- src/llmtuner/webui/common.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index c865b4f8..69f4510d 100644 --- a/src/llmtuner/extras/constants.py +++ b/src/llmtuner/extras/constants.py @@ -20,7 +20,7 @@ SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] SUPPORTED_MODELS = OrderedDict() -MODELSCOPE_MODELS = OrderedDict() +ALL_OFFICIAL_MODELS = OrderedDict() TRAINING_STAGES = { "Supervised Fine-Tuning": "sft", @@ -43,12 +43,14 @@ def register_model_group( else: assert prefix == name.split("-")[0], "prefix should be identical." + ALL_OFFICIAL_MODELS[name] = [path] if isinstance(path, str) else list(path.values()) if not int(os.environ.get('USE_MODELSCOPE_HUB', '0')): # If path is a string, we treat it as a huggingface model-id by default. SUPPORTED_MODELS[name] = path["hf"] if isinstance(path, dict) else path elif isinstance(path, dict) and "ms" in path: # Use ModelScope modelhub SUPPORTED_MODELS[name] = path["ms"] + print(f'Supported models add {name}/{SUPPORTED_MODELS[name]}') if module is not None: DEFAULT_MODULE[prefix] = module if template is not None: diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 55d8942b..b21cad62 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -11,7 +11,7 @@ from transformers.utils import ( ADAPTER_SAFE_WEIGHTS_NAME ) -from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES +from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, ALL_OFFICIAL_MODELS, TRAINING_STAGES DEFAULT_CACHE_DIR = "cache" @@ -58,7 +58,10 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona def get_model_path(model_name: str) -> str: user_config = load_config() - return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "") + cached_path = user_config["path_dict"].get(model_name, None) + if cached_path in ALL_OFFICIAL_MODELS.get(model_name, []): + cached_path = None + return cached_path or SUPPORTED_MODELS.get(model_name, "") def get_prefix(model_name: str) -> str: