Mirror of https://github.com/hiyouga/LLaMA-Factory.git
	[webui] support other hub (#8567)
commit 12ed792db9 (parent 4b0ec83928)
@@ -18,7 +18,7 @@ scipy
 # model and tokenizer
 sentencepiece
 tiktoken
-modelscope>=1.23
+modelscope>=1.14.0
 hf-transfer
 # python
 fire
@@ -91,7 +91,7 @@ def _load_single_dataset(
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

     if dataset_attr.load_from == "ms_hub":
-        check_version("modelscope>=1.11.0", mandatory=True)
+        check_version("modelscope>=1.14.0", mandatory=True)
         from modelscope import MsDataset  # type: ignore
         from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

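check_version above fails fast before the optional modelscope import is attempted. A minimal sketch of that gate pattern, assuming the packaging library is installed; require_version is a hypothetical stand-in, not LLaMA-Factory's own check_version:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def require_version(requirement: str) -> None:
    """Raise early if an optional dependency is missing or too old."""
    name, minimum = requirement.split(">=")
    try:
        installed = Version(version(name))
    except PackageNotFoundError as err:
        raise ImportError(f"{name} is required, run: pip install '{requirement}'") from err

    if installed < Version(minimum):
        raise ImportError(f"'{requirement}' is required, but found {name}=={installed}")


require_version("modelscope>=1.14.0")  # mirrors the gate in _load_single_dataset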
@@ -268,8 +268,13 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
         return model_args.model_name_or_path

     if use_modelscope():
-        check_version("modelscope>=1.11.0", mandatory=True)
+        check_version("modelscope>=1.14.0", mandatory=True)
         from modelscope import snapshot_download  # type: ignore
+        from modelscope.hub.api import HubApi  # type: ignore
+
+        if model_args.ms_hub_token:
+            api = HubApi()
+            api.login(model_args.ms_hub_token)

         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
         return snapshot_download(
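The new branch authenticates with the ModelScope token, when one is configured, before resolving the snapshot. A self-contained sketch of the same flow, assuming modelscope>=1.14.0 is installed; resolve_model_path and its defaults are illustrative, not part of the codebase:

from typing import Optional


def resolve_model_path(model_id: str, revision: str = "main", token: Optional[str] = None) -> str:
    from modelscope import snapshot_download  # type: ignore
    from modelscope.hub.api import HubApi  # type: ignore

    if token:  # gated or private repos need an authenticated session first
        HubApi().login(token)

    # ModelScope calls its default branch "master" where Hugging Face uses "main"
    ms_revision = "master" if revision == "main" else revision
    return snapshot_download(model_id, revision=ms_revision)

Calling it with a hub model id (for example resolve_model_path("Qwen/Qwen2.5-7B-Instruct")) returns a local snapshot directory, which is what try_download_model_from_other_hub substitutes for model_name_or_path.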
@@ -314,5 +319,5 @@ def fix_proxy(ipv6_enabled: bool = False) -> None:
     r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
-        for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
-            os.environ.pop(name, None)
+        os.environ.pop("http_proxy", None)
+        os.environ.pop("HTTP_PROXY", None)
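The no_proxy assignment exists so Gradio's own requests to local addresses are not routed through an inherited proxy. A small POSIX-only illustration (the proxy URL is made up):

import os
import urllib.request

os.environ["http_proxy"] = "http://proxy.example.com:8080"  # hypothetical inherited proxy
os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"      # what fix_proxy sets

print(urllib.request.getproxies())               # the proxy is still visible to urllib...
print(urllib.request.proxy_bypass("localhost"))  # ...but local hosts bypass it (truthy)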
@@ -77,14 +77,19 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
         with open(_get_config_path(), encoding="utf-8") as f:
             return safe_load(f)
     except Exception:
-        return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
+        return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}


-def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
+def save_config(
+    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+) -> None:
     r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
     user_config["lang"] = lang or user_config["lang"]
+    if hub_name:
+        user_config["hub_name"] = hub_name
+
     if model_name:
         user_config["last_model"] = model_name

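The persisted config is a small YAML file, and hub_name simply becomes one more top-level key. A self-contained sketch of the round trip, assuming PyYAML (which the webui already uses via safe_load); CONFIG_PATH is illustrative, not the path LLaMA-Factory actually writes to:

import os
from typing import Any, Optional

from yaml import safe_dump, safe_load

CONFIG_PATH = os.path.expanduser("~/.cache/llamafactory/user_config.yaml")  # hypothetical


def load_user_config() -> dict[str, Any]:
    try:
        with open(CONFIG_PATH, encoding="utf-8") as f:
            return safe_load(f)
    except Exception:  # first run or corrupt file: fall back to defaults
        return {"lang": None, "hub_name": None, "last_model": None, "path_dict": {}, "cache_dir": None}


def save_user_config(lang: str, hub_name: Optional[str] = None) -> None:
    os.makedirs(os.path.dirname(CONFIG_PATH), exist_ok=True)
    config = load_user_config()
    config["lang"] = lang or config["lang"]
    if hub_name:
        config["hub_name"] = hub_name

    with open(CONFIG_PATH, "w", encoding="utf-8") as f:
        safe_dump(config, f)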
@@ -247,7 +252,7 @@ def create_ds_config() -> None:
         "stage": 2,
         "allgather_partitions": True,
         "allgather_bucket_size": 5e8,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "reduce_scatter": True,
         "reduce_bucket_size": 5e8,
         "contiguous_gradients": True,
@@ -262,7 +267,7 @@ def create_ds_config() -> None:

     ds_config["zero_optimization"] = {
         "stage": 3,
-        "overlap_comm": True,
+        "overlap_comm": False,
         "contiguous_gradients": True,
         "sub_group_size": 1e9,
         "reduce_bucket_size": "auto",
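In DeepSpeed's ZeRO configuration, overlap_comm overlaps gradient-reduction communication with the backward pass; the upstream docs note that enabling it costs extra GPU memory (roughly proportional to the allgather bucket size). Disabling it in both generated presets therefore trades some throughput for lower peak memory.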
@@ -16,9 +16,10 @@ from typing import TYPE_CHECKING

 from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
+from ...extras.misc import use_modelscope, use_openmind
 from ...extras.packages import is_gradio_available
 from ..common import save_config
-from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints
+from ..control import can_quantize, can_quantize_to, check_template, get_model_info, list_checkpoints, switch_hub


 if is_gradio_available():
@@ -33,8 +34,10 @@ def create_top() -> dict[str, "Component"]:
     with gr.Row():
         lang = gr.Dropdown(choices=["en", "ru", "zh", "ko", "ja"], value=None, scale=1)
         available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
-        model_name = gr.Dropdown(choices=available_models, value=None, scale=3)
-        model_path = gr.Textbox(scale=3)
+        model_name = gr.Dropdown(choices=available_models, value=None, scale=2)
+        model_path = gr.Textbox(scale=2)
+        default_hub = "modelscope" if use_modelscope() else "openmind" if use_openmind() else "huggingface"
+        hub_name = gr.Dropdown(choices=["huggingface", "modelscope", "openmind"], value=default_hub, scale=2)

     with gr.Row():
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
@@ -50,18 +53,25 @@ def create_top() -> dict[str, "Component"]:
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     ).then(check_template, [lang, template])
-    model_name.input(save_config, inputs=[lang, model_name], queue=False)
-    model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
+    model_name.input(save_config, inputs=[lang, hub_name, model_name], queue=False)
+    model_path.input(save_config, inputs=[lang, hub_name, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
     checkpoint_path.focus(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False)
     quantization_method.change(can_quantize_to, [quantization_method], [quantization_bit], queue=False)
+    hub_name.change(switch_hub, inputs=[hub_name], queue=False).then(
+        get_model_info, [model_name], [model_path, template], queue=False
+    ).then(list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False).then(
+        check_template, [lang, template]
+    )
+    hub_name.input(save_config, inputs=[lang, hub_name], queue=False)

     return dict(
         lang=lang,
         model_name=model_name,
         model_path=model_path,
+        hub_name=hub_name,
         finetuning_type=finetuning_type,
         checkpoint_path=checkpoint_path,
         quantization_bit=quantization_bit,
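The hub dropdown reuses Gradio's dependency chaining: change() fires switch_hub, and each then() step runs only after the previous one, so get_model_info and list_checkpoints always see the updated environment. A stripped-down, runnable illustration of the same pattern (the components and handlers here are demo stand-ins, not the webui's own):

import os

import gradio as gr


def switch_hub(hub: str) -> None:
    os.environ["USE_MODELSCOPE_HUB"] = "1" if hub == "modelscope" else "0"
    os.environ["USE_OPENMIND_HUB"] = "1" if hub == "openmind" else "0"


def report(hub: str) -> str:
    return f"hub={hub}, USE_MODELSCOPE_HUB={os.environ.get('USE_MODELSCOPE_HUB')}"


with gr.Blocks() as demo:
    hub_name = gr.Dropdown(choices=["huggingface", "modelscope", "openmind"], value="huggingface")
    status = gr.Textbox()
    # switch_hub runs first; report only fires after it has completed.
    hub_name.change(switch_hub, inputs=[hub_name], queue=False).then(report, [hub_name], [status], queue=False)

if __name__ == "__main__":
    demo.launch()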
@@ -38,6 +38,15 @@ if is_gradio_available():
     import gradio as gr


+def switch_hub(hub_name: str) -> None:
+    r"""Switch model hub.
+
+    Inputs: top.hub_name
+    """
+    os.environ["USE_MODELSCOPE_HUB"] = "1" if hub_name == "modelscope" else "0"
+    os.environ["USE_OPENMIND_HUB"] = "1" if hub_name == "openmind" else "0"
+
+
 def can_quantize(finetuning_type: str) -> "gr.Dropdown":
     r"""Judge if the quantization is available in this finetuning type.

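switch_hub only writes environment variables; the download helpers read them back later. The readers are presumably symmetric to this writer, along these lines (the real use_modelscope/use_openmind live in extras.misc and may accept additional truthy spellings):

import os


def use_modelscope() -> bool:
    return os.environ.get("USE_MODELSCOPE_HUB", "0") == "1"


def use_openmind() -> bool:
    return os.environ.get("USE_OPENMIND_HUB", "0") == "1"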
@@ -49,8 +49,13 @@ class Engine:
     def resume(self):
         r"""Get the initial value of gradio components and restores training status if necessary."""
         user_config = load_config() if not self.demo_mode else {}  # do not use config in demo mode
-        lang = user_config.get("lang", None) or "en"
-        init_dict = {"top.lang": {"value": lang}, "infer.chat_box": {"visible": self.chatter.loaded}}
+        lang = user_config.get("lang") or "en"
+        hub_name = user_config.get("hub_name") or "huggingface"
+        init_dict = {
+            "top.lang": {"value": lang},
+            "top.hub_name": {"value": hub_name},
+            "infer.chat_box": {"visible": self.chatter.loaded},
+        }

         if not self.pure_chat:
             current_time = get_time()
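With this change, resume seeds top.hub_name from the saved user config (defaulting to huggingface), so the hub selection persists across webui sessions the same way the language does.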
@@ -39,15 +39,13 @@ def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
     hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

-    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
+    with gr.Blocks(title=f"LLaMA Factory ({hostname})", css=CSS) as demo:
         title = gr.HTML()
+        subtitle = gr.HTML()
         if demo_mode:
             gr.HTML("<h1><center>LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory</center></h1>")
-            gr.HTML(
-                '<h3><center>Visit <a href="https://github.com/hiyouga/LLaMA-Factory" target="_blank">'
-                "LLaMA Factory</a> for details.</center></h3>"
-            )
             gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

+        engine.manager.add_elems("head", {"title": title, "subtitle": subtitle})
         engine.manager.add_elems("top", create_top())
         lang: gr.Dropdown = engine.manager.get_elem_by_id("top.lang")
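The title and subtitle registered under the head group are empty gr.HTML() placeholders; their per-language contents come from the new title and subtitle entries added to LOCALES below.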
@@ -13,6 +13,55 @@
 # limitations under the License.

 LOCALES = {
+    "title": {
+        "en": {
+            "value": "<h1><center>LLaMA Factory: Unified Efficient Fine-Tuning of 100+ LLMs</center></h1>",
+        },
+        "ru": {
+            "value": "<h1><center>LLaMA Factory: Унифицированная эффективная тонкая настройка 100+ LLMs</center></h1>",
+        },
+        "zh": {
+            "value": "<h1><center>LLaMA Factory: 一站式大模型高效微调平台</center></h1>",
+        },
+        "ko": {
+            "value": "<h1><center>LLaMA Factory: 100+ LLMs를 위한 통합 효율적인 튜닝</center></h1>",
+        },
+        "ja": {
+            "value": "<h1><center>LLaMA Factory: 100+ LLMs の統合効率的なチューニング</center></h1>",
+        },
+    },
+    "subtitle": {
+        "en": {
+            "value": (
+                "<h3><center>Visit <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub Page</a></center></h3>"
+            ),
+        },
+        "ru": {
+            "value": (
+                "<h3><center>Посетить <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "страницу GitHub</a></center></h3>"
+            ),
+        },
+        "zh": {
+            "value": (
+                "<h3><center>访问 <a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 主页</a></center></h3>"
+            ),
+        },
+        "ko": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub 페이지</a>를 방문하세요.</center></h3>"
+            ),
+        },
+        "ja": {
+            "value": (
+                "<h3><center><a href='https://github.com/hiyouga/LLaMA-Factory' target='_blank'>"
+                "GitHub ページ</a>にアクセスする</center></h3>"
+            ),
+        },
+    },
     "lang": {
         "en": {
             "label": "Language",
@@ -74,6 +123,28 @@ LOCALES = {
             "info": "事前学習済みモデルへのパス、または Hugging Face のモデル識別子。",
         },
     },
+    "hub_name": {
+        "en": {
+            "label": "Hub name",
+            "info": "Choose the model download source.",
+        },
+        "ru": {
+            "label": "Имя хаба",
+            "info": "Выберите источник загрузки модели.",
+        },
+        "zh": {
+            "label": "模型下载源",
+            "info": "选择模型下载源。(网络受限环境推荐使用 ModelScope)",
+        },
+        "ko": {
+            "label": "모델 다운로드 소스",
+            "info": "모델 다운로드 소스를 선택하세요.",
+        },
+        "ja": {
+            "label": "モデルダウンロードソース",
+            "info": "モデルをダウンロードするためのソースを選択してください。",
+        },
+    },
     "finetuning_type": {
         "en": {
             "label": "Finetuning method",
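Every component id surfaced in the UI needs one LOCALES entry per supported language. A self-contained sketch of the lookup the manager presumably performs when the language changes; the English fallback is an assumption for illustration:

LOCALES = {
    "hub_name": {
        "en": {"label": "Hub name", "info": "Choose the model download source."},
        "zh": {"label": "模型下载源", "info": "选择模型下载源。"},
    },
}


def localize(component_id: str, lang: str) -> dict:
    entry = LOCALES.get(component_id, {})
    return entry.get(lang) or entry.get("en", {})  # assumed fallback to English


assert localize("hub_name", "zh")["label"] == "模型下载源"
assert localize("hub_name", "ko")["label"] == "Hub name"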