support function calling

This commit is contained in:
hiyouga
2024-01-18 09:54:23 +08:00
parent 28135d787d
commit d9f1cae351
69 changed files with 1329 additions and 1085 deletions

View File

@@ -1 +1,4 @@
from llmtuner.webui.interface import create_ui, create_web_demo
from .interface import create_ui, create_web_demo
__all__ = ["create_ui", "create_web_demo"]

View File

@@ -2,14 +2,14 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
from llmtuner.chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS
from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..hparams import GeneratingArguments
from .common import get_save_dir
from .locales import ALERTS
if TYPE_CHECKING:
from llmtuner.webui.manager import Manager
from .manager import Manager
class WebChatModel(ChatModel):

View File

@@ -5,7 +5,8 @@ from collections import defaultdict
from typing import Any, Dict, Optional
from peft.utils import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME
from llmtuner.extras.constants import (
from ..extras.constants import (
DATA_CONFIG,
DEFAULT_MODULE,
DEFAULT_TEMPLATE,
PEFT_METHODS,
@@ -13,8 +14,7 @@ from llmtuner.extras.constants import (
TRAINING_STAGES,
DownloadSource
)
from llmtuner.extras.misc import use_modelscope
from llmtuner.hparams.data_args import DATA_CONFIG
from ..extras.misc import use_modelscope
ADAPTER_NAMES = {WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME}

View File

@@ -1,6 +1,11 @@
from llmtuner.webui.components.top import create_top
from llmtuner.webui.components.train import create_train_tab
from llmtuner.webui.components.eval import create_eval_tab
from llmtuner.webui.components.infer import create_infer_tab
from llmtuner.webui.components.export import create_export_tab
from llmtuner.webui.components.chatbot import create_chat_box
from .top import create_top
from .train import create_train_tab
from .eval import create_eval_tab
from .infer import create_infer_tab
from .export import create_export_tab
from .chatbot import create_chat_box
__all__ = [
"create_top", "create_train_tab", "create_eval_tab", "create_infer_tab", "create_export_tab", "create_chat_box"
]

View File

@@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Dict, Optional, Tuple
if TYPE_CHECKING:
from gradio.blocks import Block
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_chat_box(

View File

@@ -3,7 +3,7 @@ import json
import gradio as gr
from typing import TYPE_CHECKING, Any, Dict, Tuple
from llmtuner.webui.common import DATA_CONFIG
from ...extras.constants import DATA_CONFIG
if TYPE_CHECKING:
from gradio.components import Component

View File

@@ -1,12 +1,13 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from ..common import list_dataset, DEFAULT_DATA_DIR
from .data import create_preview_box
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@@ -1,13 +1,14 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict, Generator, List
from llmtuner.train import export_model
from llmtuner.webui.common import get_save_dir
from llmtuner.webui.locales import ALERTS
from ...train import export_model
from ..common import get_save_dir
from ..locales import ALERTS
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
GPTQ_BITS = ["8", "4", "3", "2"]

View File

@@ -1,11 +1,12 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.components.chatbot import create_chat_box
from .chatbot import create_chat_box
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@@ -1,10 +1,10 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.data.template import templates
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.webui.common import get_model_path, get_template, list_adapters, save_config
from llmtuner.webui.utils import can_quantize
from ...data import templates
from ...extras.constants import METHODS, SUPPORTED_MODELS
from ..common import get_model_path, get_template, list_adapters, save_config
from ..utils import can_quantize
if TYPE_CHECKING:
from gradio.components import Component

View File

@@ -2,14 +2,15 @@ import gradio as gr
from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box
from llmtuner.webui.utils import gen_plot
from ...extras.constants import TRAINING_STAGES
from ..common import list_adapters, list_dataset, DEFAULT_DATA_DIR
from ..components.data import create_preview_box
from ..utils import gen_plot
if TYPE_CHECKING:
from gradio.components import Component
from llmtuner.webui.engine import Engine
from ..engine import Engine
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:

View File

@@ -2,12 +2,12 @@ import gradio as gr
from gradio.components import Component # cannot use TYPE_CHECKING here
from typing import Any, Dict, Generator, Optional
from llmtuner.webui.chatter import WebChatModel
from llmtuner.webui.common import get_model_path, list_dataset, load_config
from llmtuner.webui.locales import LOCALES
from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner
from llmtuner.webui.utils import get_time
from .chatter import WebChatModel
from .common import get_model_path, list_dataset, load_config
from .locales import LOCALES
from .manager import Manager
from .runner import Runner
from .utils import get_time
class Engine:

View File

@@ -2,7 +2,7 @@ import gradio as gr
from typing import Optional
from transformers.utils.versions import require_version
from llmtuner.webui.components import (
from .components import (
create_top,
create_train_tab,
create_eval_tab,
@@ -10,9 +10,9 @@ from llmtuner.webui.components import (
create_export_tab,
create_chat_box
)
from llmtuner.webui.common import save_config
from llmtuner.webui.css import CSS
from llmtuner.webui.engine import Engine
from .common import save_config
from .css import CSS
from .engine import Engine
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")

View File

@@ -9,17 +9,17 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple
import transformers
from transformers.trainer import TRAINING_ARGS_NAME
from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.extras.logging import LoggerHandler
from llmtuner.extras.misc import get_device_count, torch_gc
from llmtuner.train import run_exp
from llmtuner.webui.common import get_module, get_save_dir, load_config
from llmtuner.webui.locales import ALERTS
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
from ..extras.callbacks import LogCallback
from ..extras.constants import TRAINING_STAGES
from ..extras.logging import LoggerHandler
from ..extras.misc import get_device_count, torch_gc
from ..train import run_exp
from .common import get_module, get_save_dir, load_config
from .locales import ALERTS
from .utils import gen_cmd, get_eval_results, update_process_bar
if TYPE_CHECKING:
from llmtuner.webui.manager import Manager
from .manager import Manager
class Runner:

View File

@@ -4,12 +4,12 @@ import gradio as gr
from typing import TYPE_CHECKING, Any, Dict
from datetime import datetime
from llmtuner.extras.packages import is_matplotlib_available
from llmtuner.extras.ploting import smooth
from llmtuner.webui.common import get_save_dir
from ..extras.packages import is_matplotlib_available
from ..extras.ploting import smooth
from .common import get_save_dir
if TYPE_CHECKING:
from llmtuner.extras.callbacks import LogCallback
from ..extras.callbacks import LogCallback
if is_matplotlib_available():
import matplotlib.figure