mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
add template match and stage in webui
This commit is contained in:
@@ -6,7 +6,7 @@ import gradio as gr
|
||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
|
||||
from llmtuner.extras.constants import SUPPORTED_MODELS
|
||||
from llmtuner.extras.constants import SUPPORTED_MODELS, DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL, DEFAULT_TEMPLATE
|
||||
|
||||
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
@@ -48,6 +48,25 @@ def get_model_path(model_name: str) -> str:
|
||||
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
|
||||
|
||||
|
||||
def get_template(
|
||||
template: str,
|
||||
model_name: str,
|
||||
) -> str:
|
||||
if template and template != "default":
|
||||
return template
|
||||
if model_name == "Custom":
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
# get last dir
|
||||
basename = os.path.basename(model_name_or_path)
|
||||
# prefix match
|
||||
for k, v in DEFAULT_TEMPLATE_WITH_CUSTOM_MODEL.items():
|
||||
if basename.startswith(k):
|
||||
return v
|
||||
return "default"
|
||||
|
||||
return DEFAULT_TEMPLATE.get(model_name.split("-")[0], "default")
|
||||
|
||||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
|
||||
|
||||
Reference in New Issue
Block a user