add template match and stage in webui

This commit is contained in:
codemayq
2023-08-14 20:42:59 +08:00
parent ec94274ca1
commit 79c68e5527
6 changed files with 77 additions and 14 deletions

View File

@@ -6,7 +6,7 @@ import gradio as gr
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from llmtuner.extras.constants import SUPPORTED_MODELS
from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
DEFAULT_CACHE_DIR = "cache"
@@ -48,6 +48,25 @@ def get_model_path(model_name: str) -> str:
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
def get_template(
template: str,
model_name: str,
) -> str:
if template and template != "default":
return template
if model_name == "Custom":
model_name_or_path = get_model_path(model_name)
# get last dir
basename = os.path.basename(model_name_or_path)
# prefix match
for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items():
if basename.startswith(k):
return v
return "default"
return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
checkpoints = []
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)