# Mirror of https://github.com/hiyouga/LLaMA-Factory.git
from typing import Any, Dict, Generator, List, Optional, Tuple

from llmtuner.chat.stream_chat import ChatModel
from llmtuner.extras.misc import torch_gc
from llmtuner.hparams import GeneratingArguments
from llmtuner.webui.common import get_model_path, get_save_dir
from llmtuner.webui.locales import ALERTS


class WebChatModel(ChatModel):
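    """ChatModel wrapper for the web UI: loads and unloads weights on demand and streams chat replies."""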

    def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
        self.model = None
        self.tokenizer = None
        self.generating_args = GeneratingArguments()
        if not lazy_init:  # load immediately when lazy initialization is disabled
            super().__init__(args)

    def load_model(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        system_prompt: str,
        flash_attn: bool,
        shift_attn: bool,
        rope_scaling: str
    ) -> Generator[str, None, None]:
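        """Validate the UI selections, then initialize the underlying ChatModel, yielding localized status messages."""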
        if self.model is not None:
            yield ALERTS["err_exists"][lang]
            return

        if not model_name:
            yield ALERTS["err_no_model"][lang]
            return

        model_name_or_path = get_model_path(model_name)
        if not model_name_or_path:
            yield ALERTS["err_no_path"][lang]
            return

        if checkpoints:
            # multiple adapter checkpoints are joined into a comma-separated list of save directories
            checkpoint_dir = ",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints])
        else:
            checkpoint_dir = None

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_name_or_path,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit in ["8", "4"] else None,
            template=template,
            system_prompt=system_prompt,
            flash_attn=flash_attn,
            shift_attn=shift_attn,
            rope_scaling=rope_scaling if rope_scaling in ["linear", "dynamic"] else None
        )
        super().__init__(args)

        yield ALERTS["info_loaded"][lang]

    def unload_model(self, lang: str) -> Generator[str, None, None]:
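        """Drop model and tokenizer references and free GPU memory, yielding localized status messages."""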
        yield ALERTS["info_unloading"][lang]
        self.model = None
        self.tokenizer = None
        torch_gc()  # release cached GPU memory
        yield ALERTS["info_unloaded"][lang]

    def predict(
        self,
        chatbot: List[Tuple[str, str]],
        query: str,
        history: List[Tuple[str, str]],
        system: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float
    ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
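        """Stream a reply to the latest query, yielding updated (chatbot, history) pairs after every generated chunk."""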
        chatbot.append([query, ""])
        response = ""
        for new_text in self.stream_chat(
            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            new_history = history + [(query, response)]
            chatbot[-1] = [query, self.postprocess(response)]
            yield chatbot, new_history

    def postprocess(self, response: str) -> str:
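        """Escape raw HTML angle brackets outside of ``` code fences so they render literally in the chat widget."""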
        blocks = response.split("```")
        for i, block in enumerate(blocks):
            if i % 2 == 0:  # even-indexed blocks lie outside code fences
                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
        return "```".join(blocks)
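

# Minimal usage sketch (illustrative only; "LLaMA2-7B-Chat" and the settings below are
# placeholder values that must match a model configured in the web UI).
if __name__ == "__main__":
    demo = WebChatModel(lazy_init=True)
    for status in demo.load_model(
        lang="en",
        model_name="LLaMA2-7B-Chat",  # placeholder model name
        checkpoints=[],
        finetuning_type="lora",
        quantization_bit="none",      # anything other than "8" or "4" disables quantization
        template="default",
        system_prompt="",
        flash_attn=False,
        shift_attn=False,
        rope_scaling="none",          # anything other than "linear" or "dynamic" disables RoPE scaling
    ):
        print(status)

    chatbot, history = [], []
    for chatbot, history in demo.predict(
        chatbot, "Hello!", history, system="", max_new_tokens=128, top_p=0.7, temperature=0.95
    ):
        pass
    print(chatbot[-1][1])  # final streamed reply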