From f9c859e97b52a6b3efda06735506144654e57b96 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 17 Apr 2024 22:17:19 +0800 Subject: [PATCH] fix #3317 Former-commit-id: 6d641af70361756131eaee456362909bd82a6c58 --- src/llmtuner/chat/base_engine.py | 5 +---- src/llmtuner/extras/packages.py | 4 ++++ src/llmtuner/webui/chatter.py | 9 ++++++--- src/llmtuner/webui/common.py | 6 +++++- src/llmtuner/webui/components/chatbot.py | 7 +++++-- src/llmtuner/webui/components/data.py | 7 +++++-- src/llmtuner/webui/components/eval.py | 7 +++++-- src/llmtuner/webui/components/export.py | 7 +++++-- src/llmtuner/webui/components/infer.py | 7 +++++-- src/llmtuner/webui/components/top.py | 7 +++++-- src/llmtuner/webui/components/train.py | 6 +++++- src/llmtuner/webui/engine.py | 7 +++++-- src/llmtuner/webui/interface.py | 7 +++++-- src/llmtuner/webui/runner.py | 8 ++++++-- src/llmtuner/webui/utils.py | 13 ++++++++----- 15 files changed, 75 insertions(+), 32 deletions(-) diff --git a/src/llmtuner/chat/base_engine.py b/src/llmtuner/chat/base_engine.py index c5db41da..e19db676 100644 --- a/src/llmtuner/chat/base_engine.py +++ b/src/llmtuner/chat/base_engine.py @@ -5,14 +5,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti if TYPE_CHECKING: from transformers import PreTrainedModel, PreTrainedTokenizer + from vllm import AsyncLLMEngine from ..data import Template - from ..extras.packages import is_vllm_available from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments - if is_vllm_available(): - from vllm import AsyncLLMEngine - @dataclass class Response: diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py index b134ddab..8494cb2c 100644 --- a/src/llmtuner/extras/packages.py +++ b/src/llmtuner/extras/packages.py @@ -25,6 +25,10 @@ def is_galore_available(): return _is_package_available("galore_torch") +def is_gradio_available(): + return _is_package_available("gradio") + + def is_jieba_available(): return _is_package_available("jieba") diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 8c744153..479846ca 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -2,12 +2,10 @@ import json import os from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple -import gradio as gr -from gradio.components import Component # cannot use TYPE_CHECKING here - from ..chat import ChatModel from ..data import Role from ..extras.misc import torch_gc +from ..extras.packages import is_gradio_available from .common import get_save_dir from .locales import ALERTS @@ -17,6 +15,11 @@ if TYPE_CHECKING: from .manager import Manager +if is_gradio_available(): + import gradio as gr + from gradio.components import Component # cannot use TYPE_CHECKING here + + class WebChatModel(ChatModel): def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: self.manager = manager diff --git a/src/llmtuner/webui/common.py b/src/llmtuner/webui/common.py index 96ef2737..659c35c3 100644 --- a/src/llmtuner/webui/common.py +++ b/src/llmtuner/webui/common.py @@ -3,7 +3,6 @@ import os from collections import defaultdict from typing import Any, Dict, Optional -import gradio as gr from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME from ..extras.constants import ( @@ -17,6 +16,11 @@ from ..extras.constants import ( DownloadSource, ) from ..extras.misc import use_modelscope +from ..extras.packages import is_gradio_available + + +if is_gradio_available(): + import gradio as gr ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME} diff --git a/src/llmtuner/webui/components/chatbot.py b/src/llmtuner/webui/components/chatbot.py index 8efd333c..82bc4f29 100644 --- a/src/llmtuner/webui/components/chatbot.py +++ b/src/llmtuner/webui/components/chatbot.py @@ -1,11 +1,14 @@ from typing import TYPE_CHECKING, Dict, Tuple -import gradio as gr - from ...data import Role +from ...extras.packages import is_gradio_available from ..utils import check_json_schema +if is_gradio_available(): + import gradio as gr + + if TYPE_CHECKING: from gradio.components import Component diff --git a/src/llmtuner/webui/components/data.py b/src/llmtuner/webui/components/data.py index 8e2e04bf..232b973d 100644 --- a/src/llmtuner/webui/components/data.py +++ b/src/llmtuner/webui/components/data.py @@ -2,9 +2,12 @@ import json import os from typing import TYPE_CHECKING, Any, Dict, List, Tuple -import gradio as gr - from ...extras.constants import DATA_CONFIG +from ...extras.packages import is_gradio_available + + +if is_gradio_available(): + import gradio as gr if TYPE_CHECKING: diff --git a/src/llmtuner/webui/components/eval.py b/src/llmtuner/webui/components/eval.py index d41ef857..0b3bfc8c 100644 --- a/src/llmtuner/webui/components/eval.py +++ b/src/llmtuner/webui/components/eval.py @@ -1,11 +1,14 @@ from typing import TYPE_CHECKING, Dict -import gradio as gr - +from ...extras.packages import is_gradio_available from ..common import DEFAULT_DATA_DIR, list_dataset from .data import create_preview_box +if is_gradio_available(): + import gradio as gr + + if TYPE_CHECKING: from gradio.components import Component diff --git a/src/llmtuner/webui/components/export.py b/src/llmtuner/webui/components/export.py index b394d75c..d9c2d8e4 100644 --- a/src/llmtuner/webui/components/export.py +++ b/src/llmtuner/webui/components/export.py @@ -1,12 +1,15 @@ from typing import TYPE_CHECKING, Dict, Generator, List -import gradio as gr - +from ...extras.packages import is_gradio_available from ...train import export_model from ..common import get_save_dir from ..locales import ALERTS +if is_gradio_available(): + import gradio as gr + + if TYPE_CHECKING: from gradio.components import Component diff --git a/src/llmtuner/webui/components/infer.py b/src/llmtuner/webui/components/infer.py index 1e56d432..d565347e 100644 --- a/src/llmtuner/webui/components/infer.py +++ b/src/llmtuner/webui/components/infer.py @@ -1,10 +1,13 @@ from typing import TYPE_CHECKING, Dict -import gradio as gr - +from ...extras.packages import is_gradio_available from .chatbot import create_chat_box +if is_gradio_available(): + import gradio as gr + + if TYPE_CHECKING: from gradio.components import Component diff --git a/src/llmtuner/webui/components/top.py b/src/llmtuner/webui/components/top.py index 6c5030cd..6cbf6e0d 100644 --- a/src/llmtuner/webui/components/top.py +++ b/src/llmtuner/webui/components/top.py @@ -1,13 +1,16 @@ from typing import TYPE_CHECKING, Dict -import gradio as gr - from ...data import templates from ...extras.constants import METHODS, SUPPORTED_MODELS +from ...extras.packages import is_gradio_available from ..common import get_model_path, get_template, list_adapters, save_config from ..utils import can_quantize +if is_gradio_available(): + import gradio as gr + + if TYPE_CHECKING: from gradio.components import Component diff --git a/src/llmtuner/webui/components/train.py b/src/llmtuner/webui/components/train.py index 10954c1b..eaa266d9 100644 --- a/src/llmtuner/webui/components/train.py +++ b/src/llmtuner/webui/components/train.py @@ -1,13 +1,17 @@ from typing import TYPE_CHECKING, Dict -import gradio as gr from transformers.trainer_utils import SchedulerType from ...extras.constants import TRAINING_STAGES +from ...extras.packages import is_gradio_available from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset from ..components.data import create_preview_box +if is_gradio_available(): + import gradio as gr + + if TYPE_CHECKING: from gradio.components import Component diff --git a/src/llmtuner/webui/engine.py b/src/llmtuner/webui/engine.py index 0ee7f047..65945533 100644 --- a/src/llmtuner/webui/engine.py +++ b/src/llmtuner/webui/engine.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Generator -from gradio.components import Component # cannot use TYPE_CHECKING here - +from ..extras.packages import is_gradio_available from .chatter import WebChatModel from .common import get_model_path, list_dataset, load_config from .locales import LOCALES @@ -10,6 +9,10 @@ from .runner import Runner from .utils import get_time +if is_gradio_available(): + from gradio.components import Component # cannot use TYPE_CHECKING here + + class Engine: def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: self.demo_mode = demo_mode diff --git a/src/llmtuner/webui/interface.py b/src/llmtuner/webui/interface.py index f89d3ca5..0359d082 100644 --- a/src/llmtuner/webui/interface.py +++ b/src/llmtuner/webui/interface.py @@ -1,5 +1,4 @@ -import gradio as gr - +from ..extras.packages import is_gradio_available from .common import save_config from .components import ( create_chat_box, @@ -13,6 +12,10 @@ from .css import CSS from .engine import Engine +if is_gradio_available(): + import gradio as gr + + def create_ui(demo_mode: bool = False) -> gr.Blocks: engine = Engine(demo_mode=demo_mode, pure_chat=False) diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index ef5379cd..12307234 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -4,9 +4,7 @@ import time from threading import Thread from typing import TYPE_CHECKING, Any, Dict, Generator -import gradio as gr import transformers -from gradio.components import Component # cannot use TYPE_CHECKING here from transformers.trainer import TRAINING_ARGS_NAME from transformers.utils import is_torch_cuda_available @@ -14,12 +12,18 @@ from ..extras.callbacks import LogCallback from ..extras.constants import TRAINING_STAGES from ..extras.logging import LoggerHandler from ..extras.misc import get_device_count, torch_gc +from ..extras.packages import is_gradio_available from ..train import run_exp from .common import get_module, get_save_dir, load_args, load_config, save_args from .locales import ALERTS from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar +if is_gradio_available(): + import gradio as gr + from gradio.components import Component # cannot use TYPE_CHECKING here + + if TYPE_CHECKING: from .manager import Manager diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index d96b1f6b..74f74e6a 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -3,21 +3,24 @@ import os from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, Optional -import gradio as gr - -from ..extras.packages import is_matplotlib_available +from ..extras.packages import is_gradio_available, is_matplotlib_available from ..extras.ploting import smooth from .locales import ALERTS -if TYPE_CHECKING: - from ..extras.callbacks import LogCallback +if is_gradio_available(): + import gradio as gr + if is_matplotlib_available(): import matplotlib.figure import matplotlib.pyplot as plt +if TYPE_CHECKING: + from ..extras.callbacks import LogCallback + + def update_process_bar(callback: "LogCallback") -> "gr.Slider": if not callback.max_steps: return gr.Slider(visible=False)