mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 20:52:59 +08:00
fix llama board
Former-commit-id: c46879575f434b2b458bddae6db63b227db4202e
This commit is contained in:
parent
8154b4bdf6
commit
16cc0321f2
@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
|
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
|
||||||
|
|
||||||
@ -52,8 +53,8 @@ def save_config(lang: str, model_name: Optional[str] = None, model_path: Optiona
|
|||||||
|
|
||||||
def get_model_path(model_name: str) -> str:
|
def get_model_path(model_name: str) -> str:
|
||||||
user_config = load_config()
|
user_config = load_config()
|
||||||
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, [])
|
path_dict: Dict[DownloadSource, str] = SUPPORTED_MODELS.get(model_name, defaultdict(str))
|
||||||
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, "")
|
model_path = user_config["path_dict"].get(model_name, None) or path_dict.get(DownloadSource.DEFAULT, None)
|
||||||
if (
|
if (
|
||||||
use_modelscope()
|
use_modelscope()
|
||||||
and path_dict.get(DownloadSource.MODELSCOPE)
|
and path_dict.get(DownloadSource.MODELSCOPE)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user