Former-commit-id: 6d641af70361756131eaee456362909bd82a6c58
This commit is contained in:
hiyouga 2024-04-17 22:17:19 +08:00
parent dac0c1a52b
commit f9c859e97b
15 changed files with 75 additions and 32 deletions

View File

@ -5,14 +5,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from vllm import AsyncLLMEngine
from ..data import Template from ..data import Template
from ..extras.packages import is_vllm_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
if is_vllm_available():
from vllm import AsyncLLMEngine
@dataclass @dataclass
class Response: class Response:

View File

@ -25,6 +25,10 @@ def is_galore_available():
return _is_package_available("galore_torch") return _is_package_available("galore_torch")
def is_gradio_available():
return _is_package_available("gradio")
def is_jieba_available(): def is_jieba_available():
return _is_package_available("jieba") return _is_package_available("jieba")

View File

@ -2,12 +2,10 @@ import json
import os import os
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Sequence, Tuple
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from ..chat import ChatModel from ..chat import ChatModel
from ..data import Role from ..data import Role
from ..extras.misc import torch_gc from ..extras.misc import torch_gc
from ..extras.packages import is_gradio_available
from .common import get_save_dir from .common import get_save_dir
from .locales import ALERTS from .locales import ALERTS
@ -17,6 +15,11 @@ if TYPE_CHECKING:
from .manager import Manager from .manager import Manager
if is_gradio_available():
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
class WebChatModel(ChatModel): class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager self.manager = manager

View File

@ -3,7 +3,6 @@ import os
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import gradio as gr
from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME from peft.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME
from ..extras.constants import ( from ..extras.constants import (
@ -17,6 +16,11 @@ from ..extras.constants import (
DownloadSource, DownloadSource,
) )
from ..extras.misc import use_modelscope from ..extras.misc import use_modelscope
from ..extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME} ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}

View File

@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Dict, Tuple from typing import TYPE_CHECKING, Dict, Tuple
import gradio as gr
from ...data import Role from ...data import Role
from ...extras.packages import is_gradio_available
from ..utils import check_json_schema from ..utils import check_json_schema
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -2,9 +2,12 @@ import json
import os import os
from typing import TYPE_CHECKING, Any, Dict, List, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Tuple
import gradio as gr
from ...extras.constants import DATA_CONFIG from ...extras.constants import DATA_CONFIG
from ...extras.packages import is_gradio_available
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -1,11 +1,14 @@
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
import gradio as gr from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, list_dataset from ..common import DEFAULT_DATA_DIR, list_dataset
from .data import create_preview_box from .data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -1,12 +1,15 @@
from typing import TYPE_CHECKING, Dict, Generator, List from typing import TYPE_CHECKING, Dict, Generator, List
import gradio as gr from ...extras.packages import is_gradio_available
from ...train import export_model from ...train import export_model
from ..common import get_save_dir from ..common import get_save_dir
from ..locales import ALERTS from ..locales import ALERTS
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -1,10 +1,13 @@
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
import gradio as gr from ...extras.packages import is_gradio_available
from .chatbot import create_chat_box from .chatbot import create_chat_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -1,13 +1,16 @@
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
import gradio as gr
from ...data import templates from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS from ...extras.constants import METHODS, SUPPORTED_MODELS
from ...extras.packages import is_gradio_available
from ..common import get_model_path, get_template, list_adapters, save_config from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize from ..utils import can_quantize
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -1,13 +1,17 @@
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
import gradio as gr
from transformers.trainer_utils import SchedulerType from transformers.trainer_utils import SchedulerType
from ...extras.constants import TRAINING_STAGES from ...extras.constants import TRAINING_STAGES
from ...extras.packages import is_gradio_available
from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset from ..common import DEFAULT_DATA_DIR, autoset_packing, list_adapters, list_dataset
from ..components.data import create_preview_box from ..components.data import create_preview_box
if is_gradio_available():
import gradio as gr
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.components import Component from gradio.components import Component

View File

@ -1,7 +1,6 @@
from typing import Any, Dict, Generator from typing import Any, Dict, Generator
from gradio.components import Component # cannot use TYPE_CHECKING here from ..extras.packages import is_gradio_available
from .chatter import WebChatModel from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES from .locales import LOCALES
@ -10,6 +9,10 @@ from .runner import Runner
from .utils import get_time from .utils import get_time
if is_gradio_available():
from gradio.components import Component # cannot use TYPE_CHECKING here
class Engine: class Engine:
def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None: def __init__(self, demo_mode: bool = False, pure_chat: bool = False) -> None:
self.demo_mode = demo_mode self.demo_mode = demo_mode

View File

@ -1,5 +1,4 @@
import gradio as gr from ..extras.packages import is_gradio_available
from .common import save_config from .common import save_config
from .components import ( from .components import (
create_chat_box, create_chat_box,
@ -13,6 +12,10 @@ from .css import CSS
from .engine import Engine from .engine import Engine
if is_gradio_available():
import gradio as gr
def create_ui(demo_mode: bool = False) -> gr.Blocks: def create_ui(demo_mode: bool = False) -> gr.Blocks:
engine = Engine(demo_mode=demo_mode, pure_chat=False) engine = Engine(demo_mode=demo_mode, pure_chat=False)

View File

@ -4,9 +4,7 @@ import time
from threading import Thread from threading import Thread
from typing import TYPE_CHECKING, Any, Dict, Generator from typing import TYPE_CHECKING, Any, Dict, Generator
import gradio as gr
import transformers import transformers
from gradio.components import Component # cannot use TYPE_CHECKING here
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from transformers.utils import is_torch_cuda_available from transformers.utils import is_torch_cuda_available
@ -14,12 +12,18 @@ from ..extras.callbacks import LogCallback
from ..extras.constants import TRAINING_STAGES from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_gradio_available
from ..train import run_exp from ..train import run_exp
from .common import get_module, get_save_dir, load_args, load_config, save_args from .common import get_module, get_save_dir, load_args, load_config, save_args
from .locales import ALERTS from .locales import ALERTS
from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar from .utils import gen_cmd, gen_plot, get_eval_results, update_process_bar
if is_gradio_available():
import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
if TYPE_CHECKING: if TYPE_CHECKING:
from .manager import Manager from .manager import Manager

View File

@ -3,21 +3,24 @@ import os
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Optional
import gradio as gr from ..extras.packages import is_gradio_available, is_matplotlib_available
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth from ..extras.ploting import smooth
from .locales import ALERTS from .locales import ALERTS
if TYPE_CHECKING: if is_gradio_available():
from ..extras.callbacks import LogCallback import gradio as gr
if is_matplotlib_available(): if is_matplotlib_available():
import matplotlib.figure import matplotlib.figure
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
if TYPE_CHECKING:
from ..extras.callbacks import LogCallback
def update_process_bar(callback: "LogCallback") -> "gr.Slider": def update_process_bar(callback: "LogCallback") -> "gr.Slider":
if not callback.max_steps: if not callback.max_steps:
return gr.Slider(visible=False) return gr.Slider(visible=False)