from typing import Any, Dict, Generator, List, Optional, Tuple

from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS


class WebChatModel(ChatModel):
    """ChatModel subclass driven by the web UI.

    Loading is deferred by default (``lazy_init=True``): the instance starts with
    no model, and :meth:`load_model` / :meth:`unload_model` load and release one
    on demand. Those two methods are generators that yield localized status
    strings looked up as ``ALERTS[key][lang]``.
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
        """Create the wrapper, loading a model immediately only if ``lazy_init`` is false.

        Args:
            args: Arguments forwarded to ``ChatModel.__init__`` when not lazy.
            lazy_init: When true (default), skip ``ChatModel.__init__`` and
                start in the unloaded state.
        """
        # Begin "unloaded"; load_model() relies on self.model being None to
        # decide that no model is currently loaded.
        self.model = None
        self.tokenizer = None
        self.generating_args = GeneratingArguments()
        if not lazy_init:
            super().__init__(args)

    def load_model(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        system_prompt: str,
        flash_attn: bool,
        shift_attn: bool,
        rope_scaling: str
    ) -> Generator[str, None, None]:
        """Load a model from the UI settings, yielding status messages.

        Yields a single error message and stops if a model is already loaded,
        if ``model_name`` is empty, or if ``get_model_path`` returns nothing for
        it. Otherwise yields ``info_loading``, calls ``ChatModel.__init__`` with
        the assembled arguments, then yields ``info_loaded``.

        Args:
            lang: Locale key used to pick the message from ``ALERTS``.
            model_name: UI model name, resolved through ``get_model_path``.
            checkpoints: Checkpoint names; each is turned into a directory with
                ``get_save_dir`` and the results are comma-joined.
            finetuning_type: Passed through as ``finetuning_type``.
            quantization_bit: UI string; only ``"8"`` or ``"4"`` are converted
                to ``int``, any other value becomes ``None``.
            template: Passed through as ``template``.
            system_prompt: Passed through as ``system_prompt``.
            flash_attn: Passed through as ``flash_attn``.
            shift_attn: Passed through as ``shift_attn``.
            rope_scaling: Only ``"linear"`` or ``"dynamic"`` are kept; any other
                value becomes ``None``.
        """
        if self.model is not None:
            yield ALERTS["err_exists"][lang]
            return

        if not model_name:
            yield ALERTS["err_no_model"][lang]
            return

        model_name_or_path = get_model_path(model_name)
        if not model_name_or_path:
            yield ALERTS["err_no_path"][lang]
            return

        if checkpoints:
            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
        else:
            checkpoint_dir = None

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_name_or_path,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
            template=template,
            system_prompt=system_prompt,
            flash_attn=flash_attn,
            shift_attn=shift_attn,
            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
        )
        super().__init__(args)

        yield ALERTS["info_loaded"][lang]

    def unload_model(self, lang: str) -> Generator[str, None, None]:
        """Drop the model and tokenizer references, yielding status messages.

        Args:
            lang: Locale key used to pick the message from ``ALERTS``.
        """
        yield ALERTS["info_unloading"][lang]
        self.model = None
        self.tokenizer = None
        # Presumably releases cached accelerator memory now that the model
        # references are gone (see llmtuner.extras.misc.torch_gc) — confirm.
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    def predict(
        self,
        chatbot: List[Tuple[str, str]],
        query: str,
        history: List[Tuple[str, str]],
        system: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float
    ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
        """Stream a reply to ``query``, yielding updated chat state after each chunk.

        ``chatbot`` is mutated in place. A new ``[query, ""]`` entry is appended,
        and that last entry is overwritten with the display-escaped partial
        response as text arrives. ``history`` is not mutated.

        Args:
            chatbot: Display transcript shown in the UI (mutated in place).
            query: The new user message.
            history: Prior ``(query, response)`` pairs passed to ``stream_chat``.
            system: System prompt passed to ``stream_chat``.
            max_new_tokens: Generation option forwarded to ``stream_chat``.
            top_p: Generation option forwarded to ``stream_chat``.
            temperature: Generation option forwarded to ``stream_chat``.

        Yields:
            ``(chatbot, new_history)``, where ``new_history`` is ``history`` plus
            ``(query, response_so_far)`` with the unescaped response.
        """
        chatbot.append([query, ""])
        response = ""
        for new_text in self.stream_chat(
            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            new_history = history + [(query, response)]
            chatbot[-1] = [query, self.postprocess(response)]
            yield chatbot, new_history

    def postprocess(self, response: str) -> str:
        """Escape ``<`` and ``>`` in ``response`` outside of ``` code fences.

        Splitting on the fence marker alternates between outside-fence segments
        (even indices) and inside-fence segments (odd indices). Only the
        outside segments are escaped. Fenced code already renders literally, and
        escaping it would show ``&lt;`` verbatim.

        The previous code replaced ``"<"`` with ``"<"`` (a no-op), so angle-bracket
        text outside fences was never escaped. It presumably could be
        interpreted as HTML by the chat widget — confirm against the UI renderer.
        """
        blocks = response.split("```")
        for i, block in enumerate(blocks):
            if i % 2 == 0:
                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
        return "```".join(blocks)