Former-commit-id: e1dcb8e4dc958a677bf484e27aec43b9710d7287
This commit is contained in:
hiyouga 2023-10-10 17:41:13 +08:00
parent 3ba788fc2c
commit 7082526df5
9 changed files with 22 additions and 29 deletions

View File

@ -1,5 +1,4 @@
import os from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from llmtuner.chat.stream_chat import ChatModel from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc from llmtuner.extras.misc import torch_gc
@ -11,11 +10,10 @@ from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel): class WebChatModel(ChatModel):
def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
if lazy_init:
self.model = None self.model = None
self.tokenizer = None self.tokenizer = None
self.generating_args = GeneratingArguments() self.generating_args = GeneratingArguments()
else: if not lazy_init:
super().__init__(args) super().__init__(args)
def load_model( def load_model(
@ -30,7 +28,7 @@ class WebChatModel(ChatModel):
flash_attn: bool, flash_attn: bool,
shift_attn: bool, shift_attn: bool,
rope_scaling: str rope_scaling: str
): ) -> Generator[str, None, None]:
if self.model is not None: if self.model is not None:
yield ALERTS["err_exists"][lang] yield ALERTS["err_exists"][lang]
return return
@ -65,7 +63,7 @@ class WebChatModel(ChatModel):
yield ALERTS["info_loaded"][lang] yield ALERTS["info_loaded"][lang]
def unload_model(self, lang: str): def unload_model(self, lang: str) -> Generator[str, None, None]:
yield ALERTS["info_unloading"][lang] yield ALERTS["info_unloading"][lang]
self.model = None self.model = None
self.tokenizer = None self.tokenizer = None
@ -81,16 +79,15 @@ class WebChatModel(ChatModel):
max_new_tokens: int, max_new_tokens: int,
top_p: float, top_p: float,
temperature: float temperature: float
): ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
chatbot.append([query, ""]) chatbot.append([query, ""])
response = "" response = ""
for new_text in self.stream_chat( for new_text in self.stream_chat(
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
): ):
response += new_text response += new_text
response = self.postprocess(response)
new_history = history + [(query, response)] new_history = history + [(query, response)]
chatbot[-1] = [query, response] chatbot[-1] = [query, self.postprocess(response)]
yield chatbot, new_history yield chatbot, new_history
def postprocess(self, response: str) -> str: def postprocess(self, response: str) -> str:

View File

@ -1,8 +1,7 @@
import json
import os import os
from typing import Any, Dict, Optional import json
import gradio as gr import gradio as gr
from typing import Any, Dict, Optional
from transformers.utils import ( from transformers.utils import (
WEIGHTS_NAME, WEIGHTS_NAME,
WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict, Optional, Tuple
if TYPE_CHECKING: if TYPE_CHECKING:
from gradio.blocks import Block from gradio.blocks import Block

View File

@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box from llmtuner.webui.components.data import create_preview_box

View File

@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.utils import save_model from llmtuner.webui.utils import save_model

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.webui.chat import WebChatModel from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.components.chatbot import create_chat_box from llmtuner.webui.components.chatbot import create_chat_box

View File

@ -1,6 +1,5 @@
from typing import TYPE_CHECKING, Dict
import gradio as gr import gradio as gr
from typing import TYPE_CHECKING, Dict
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
from llmtuner.extras.template import templates from llmtuner.extras.template import templates

View File

@ -1,8 +1,7 @@
import gradio as gr
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING, Dict
from transformers.trainer_utils import SchedulerType from transformers.trainer_utils import SchedulerType
import gradio as gr
from llmtuner.extras.constants import TRAINING_STAGES from llmtuner.extras.constants import TRAINING_STAGES
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
from llmtuner.webui.components.data import create_preview_box from llmtuner.webui.components.data import create_preview_box

View File

@ -1,11 +1,12 @@
import gradio as gr
import logging
import os import os
import threading
import time import time
import logging
import threading
import gradio as gr
from typing import Any, Dict, Generator, List, Tuple
import transformers import transformers
from transformers.trainer import TRAINING_ARGS_NAME from transformers.trainer import TRAINING_ARGS_NAME
from typing import Any, Dict, Generator, List, Tuple
from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.callbacks import LogCallback
from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES from llmtuner.extras.constants import DEFAULT_MODULE, TRAINING_STAGES