update webui

This commit is contained in:
hiyouga
2023-08-14 22:45:26 +08:00
parent adb0f186e9
commit 9d0f6214b6
8 changed files with 47 additions and 78 deletions

View File

@@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
DEFAULT_CACHE_DIR = "cache"
@@ -48,20 +48,10 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def get_template(
model_name: str,
) -> str:
if model_name == "Custom":
model_name_or_path = get_model_path(model_name)
# get last dir
basename = os.path.basename(model_name_or_path)
# prefix match
for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items():
if basename.startswith(k):
return v
return "default"
return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")
def get_template(model_name: str) -> str:
if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
return DEFAULT_TEMPLATE[model_name.split("-")[0]]
return "default"
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]: