mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
update webui
This commit is contained in:
@@ -6,7 +6,7 @@ import gradio as gr
|
||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
|
||||
from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
|
||||
from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
|
||||
|
||||
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
@@ -48,20 +48,10 @@ def get_model_path(model_name: str) -> str:
|
||||
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
|
||||
|
||||
|
||||
def get_template(model_name: str) -> str:
    """Resolve the chat-template name to use for *model_name*.

    For the special "Custom" entry, the last path component of the
    locally configured model path is matched by prefix against
    ``DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL``; any other model is looked
    up by its family name (the text before the first dash) in
    ``DEFAULT_TEMPLATE``. Falls back to "default" in both cases.
    """
    if model_name != "Custom":
        return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")

    # Last directory component of the user-configured model path.
    basename = os.path.basename(get_model_path(model_name))
    # First prefix match wins, preserving the dict's declaration order.
    return next(
        (
            template
            for prefix, template in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items()
            if basename.startswith(prefix)
        ),
        "default",
    )
|
||||
def get_template(model_name: str) -> str:
    """Return the chat-template name registered for *model_name*.

    Only model names ending in "Chat" receive a family-specific
    template; the family key is the text before the first dash in the
    model name. Every other model falls back to "default".
    """
    if not model_name.endswith("Chat"):
        return "default"
    family = model_name.split("-")[0]
    return DEFAULT_TEMPLATE.get(family, "default")
|
||||
|
||||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
|
||||
Reference in New Issue
Block a user