# coding=utf-8
# Implements a browser-based user interface for fine-tuned models.
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint

import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
from transformers.utils.versions import require_version

from llmtuner import Template, get_infer_args, load_model_and_tokenizer, get_logits_processor

require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")

model_args, data_args, finetuning_args, generating_args = get_infer_args()
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

prompt_template = Template(data_args.prompt_template)
source_prefix = data_args.source_prefix if data_args.source_prefix else ""


def predict(query, chatbot, max_new_tokens, top_p, temperature, history):
    chatbot.append((query, ""))

    # Build the model input from the templated prompt and move it to the model's device.
    input_ids = tokenizer([prompt_template.get_prompt(query, history, source_prefix)], return_tensors="pt")["input_ids"]
    input_ids = input_ids.to(model.device)

    # Stream decoded tokens as they are produced, skipping the prompt and special tokens.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    gen_kwargs = generating_args.to_dict()
    gen_kwargs.update({
        "input_ids": input_ids,
        "top_p": top_p,
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "logits_processor": get_logits_processor(),
        "streamer": streamer
    })

    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    # Yield partial responses to Gradio as new text arrives from the streamer.
    response = ""
    for new_text in streamer:
        response += new_text
        new_history = history + [(query, response)]
        chatbot[-1] = (query, response)
        yield chatbot, new_history


def reset_user_input():
    return gr.update(value="")


def reset_state():
    return [], []


with gr.Blocks() as demo:
    gr.HTML("""