fix llama board

Former-commit-id: c46879575f
This commit is contained in:
hiyouga
2023-12-16 22:17:37 +08:00
parent 8154b4bdf6
commit 16cc0321f2

View File

@@ -1,6 +1,7 @@
import os import os
import json import json
import gradio as gr import gradio as gr
from collections import defaultdict
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
@@ -52,8 +53,8 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
def get_model_path(model_name: str) -> str: def get_model_path(model_name: str) -> str:
user_config = load_config() user_config = load_config()
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, []) path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "") model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
if ( if (
use_modelscope() use_modelscope()
and path_dict.get(DownloadSource.MODELSCOPE) and path_dict.get(DownloadSource.MODELSCOPE)