mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-05 05:02:50 +08:00
98 lines
3.0 KiB
Python
98 lines
3.0 KiB
Python
import os
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
from llmtuner.chat.stream_chat import ChatModel
|
|
from llmtuner.extras.misc import torch_gc
|
|
from llmtuner.hparams import GeneratingArguments
|
|
from llmtuner.webui.common import get_model_path, get_save_dir
|
|
from llmtuner.webui.locales import ALERTS
|
|
|
|
|
|
class WebChatModel(ChatModel):
    """Chat model for the web UI whose weights can be loaded and unloaded on demand.

    With ``lazy_init=True`` (the default) the instance starts without a model
    or tokenizer; ``load_model`` later runs the parent ``ChatModel``
    initializer with arguments assembled from the UI fields.
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
        """Create the chat model, optionally deferring the actual model load.

        Args:
            args: Argument dictionary forwarded to ``ChatModel.__init__`` when
                ``lazy_init`` is false. Ignored otherwise.
            lazy_init: When true, skip the parent initializer and only set up
                placeholder attributes so ``load_model`` can be called later.
        """
        if lazy_init:
            # Placeholders: ``load_model`` checks ``self.model is not None``
            # to detect that a model has already been loaded.
            self.model = None
            self.tokenizer = None
            self.generating_args = GeneratingArguments()
        else:
            super().__init__(args)

    def load_model(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        template: str,
        system_prompt: str
    ):
        """Load a model and yield localized status messages while doing so.

        This is a generator: every yielded value is an entry of ``ALERTS``
        selected by ``lang``. On a validation failure a single error message
        is yielded and the generator stops without loading anything.

        Args:
            lang: Language key used to index each ``ALERTS`` entry.
            model_name: Model name, resolved to a path via ``get_model_path``.
            checkpoints: Checkpoint directory names located under
                ``get_save_dir(model_name)/finetuning_type``; may be empty.
            finetuning_type: Passed through to the parent initializer and used
                as a path component for the checkpoints.
            quantization_bit: Bit width as a string. An empty value or the
                literal string ``"None"`` disables quantization.
            template: Passed through to the parent initializer.
            system_prompt: Passed through to the parent initializer.

        Yields:
            Localized alert strings (error, "loading", then "loaded").
        """
        if self.model is not None:
            yield ALERTS["err_exists"][lang]
            return

        if not model_name:
            yield ALERTS["err_no_model"][lang]
            return

        model_name_or_path = get_model_path(model_name)
        if not model_name_or_path:
            yield ALERTS["err_no_path"][lang]
            return

        if checkpoints:
            # Several checkpoints are handed over as one comma-separated string.
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
        else:
            checkpoint_dir = None

        yield ALERTS["info_loading"][lang]
        args = dict(
            model_name_or_path=model_name_or_path,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            # The value arrives as a string; "" and "None" both mean "no quantization".
            quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None,
            template=template,
            system_prompt=system_prompt
        )
        # Run the deferred parent initializer now that the arguments are known.
        super().__init__(args)

        yield ALERTS["info_loaded"][lang]

    def unload_model(self, lang: str):
        """Drop the model and tokenizer references and yield status messages.

        Args:
            lang: Language key used to index each ``ALERTS`` entry.

        Yields:
            The localized "unloading" message, then the "unloaded" message
            after the references were cleared and ``torch_gc`` was called.
        """
        yield ALERTS["info_unloading"][lang]
        self.model = None
        self.tokenizer = None
        torch_gc()
        yield ALERTS["info_unloaded"][lang]

    def predict(
        self,
        chatbot: List[Tuple[str, str]],
        query: str,
        history: List[Tuple[str, str]],
        system: str,
        max_new_tokens: int,
        top_p: float,
        temperature: float
    ):
        """Stream a reply to ``query``, yielding the updated chat after each chunk.

        Args:
            chatbot: Displayed conversation as ``[query, response]`` pairs.
                It is mutated in place: a new pair is appended and its
                response is overwritten as text arrives.
            query: The new user message.
            history: Previous ``(query, response)`` turns. Not mutated; a new
                list is built for every yield.
            system: System prompt forwarded to ``stream_chat``.
            max_new_tokens: Forwarded to ``stream_chat``.
            top_p: Forwarded to ``stream_chat``.
            temperature: Forwarded to ``stream_chat``.

        Yields:
            ``(chatbot, new_history)`` after every streamed piece of text,
            where both end with the response accumulated so far.
        """
        chatbot.append([query, ""])
        response = ""
        for new_text in self.stream_chat(
            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
        ):
            response += new_text
            # NOTE(review): the accumulated (already post-processed) text is
            # post-processed again on every chunk, as in the original code.
            response = self.postprocess(response)
            new_history = history + [(query, response)]
            chatbot[-1] = [query, response]
            yield chatbot, new_history

    def postprocess(self, response: str) -> str:
        """Escape angle brackets that appear outside of ``` code fences.

        The text is split on the fence marker, so even-indexed segments lie
        outside fences and odd-indexed segments lie inside them. Only the
        outside segments are escaped; fenced code is returned untouched.

        Args:
            response: Raw model output.

        Returns:
            ``response`` with ``<`` and ``>`` replaced by ``&lt;`` and
            ``&gt;`` everywhere except inside ``` fences.
        """
        blocks = response.split("```")
        for i, block in enumerate(blocks):
            if i % 2 == 0:
                # Previously each character was replaced by itself (a no-op);
                # use the HTML entities so the text is not interpreted as markup.
                blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
        return "```".join(blocks)
|