diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py
index 4e1ea0ae..728fbd6b 100644
--- a/src/llmtuner/webui/common.py
+++ b/src/llmtuner/webui/common.py
@@ -49,11 +49,8 @@ def get_model_path(model_name: str) -> str:
 
 
 def get_template(
-    template: str,
     model_name: str,
 ) -> str:
-    if template and template != "default":
-        return template
     if model_name == "Custom":
         model_name_or_path = get_model_path(model_name)
         # get last dir
diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py
index 8611e280..caceb1e5 100644
--- a/src/llmtuner/webui/components/top.py
+++ b/src/llmtuner/webui/components/top.py
@@ -4,7 +4,7 @@ import gradio as gr
 
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.common import list_checkpoint, get_model_path, save_config, get_template
 from llmtuner.webui.utils import can_quantize
 
 if TYPE_CHECKING:
@@ -39,6 +39,7 @@ def create_top() -> Dict[str, "Component"]:
     )  # do not save config since the below line will save
 
     model_path.change(save_config, [lang, model_name, model_path])
+    model_path.change(get_template, [model_name], [template])
 
     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py
index 00918297..07811258 100644
--- a/src/llmtuner/webui/runner.py
+++ b/src/llmtuner/webui/runner.py
@@ -113,7 +113,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
-            template=get_template(template, model_name),
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
@@ -197,7 +197,7 @@ class Runner:
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
             quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
-            template=get_template(template, model_name),
+            template=template,
             source_prefix=source_prefix,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
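Note: the net effect of this diff is that the chat template is resolved once in the UI layer (the model_path.change event fills the template dropdown) and runner.py consumes the dropdown value as-is, instead of re-deriving it with get_template at launch time. Below is a minimal, self-contained sketch of the new event wiring; it assumes a Gradio Blocks layout similar to create_top(), and the TEMPLATES mapping plus the get_template body are hypothetical stand-ins, since the real resolution logic in llmtuner.webui.common is not shown in this diff.

import gradio as gr

# Hypothetical name-to-template mapping; the real lookup lives in
# llmtuner.webui.common.get_template and is not part of this diff.
TEMPLATES = {"llama2": "llama2", "baichuan": "baichuan"}

def get_template(model_name: str) -> str:
    # Post-change signature: the template is inferred from the model name
    # alone; the old `template` argument and its early return are removed.
    for prefix, name in TEMPLATES.items():
        if model_name.lower().startswith(prefix):
            return name
    return "default"

with gr.Blocks() as demo:
    model_name = gr.Dropdown(["LLaMA2-7B", "Baichuan-7B", "Custom"], label="Model name")
    model_path = gr.Textbox(label="Model path")
    template = gr.Dropdown(["default", "llama2", "baichuan"], label="Template")

    # Mirrors the line added in top.py: whenever the model path changes,
    # re-infer the template from the selected model and push it into the
    # template dropdown, so the runner can consume the dropdown value directly.
    model_path.change(get_template, [model_name], [template])

demo.launch()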