diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 019812c7..0ad2929e 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -137,8 +137,13 @@ def get_template(model_name: str) -> str: r""" Gets the template name if the model is a chat model. """ - if model_name and model_name.endswith("Chat") and get_prefix(model_name) in DEFAULT_TEMPLATE: + if ( + model_name + and any(suffix in model_name for suffix in ("-Chat", "-Instruct")) + and get_prefix(model_name) in DEFAULT_TEMPLATE + ): return DEFAULT_TEMPLATE[get_prefix(model_name)] + return "default"