From eba31ae313f043b46ec69cfe5309ef69f38407d5 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Thu, 6 Mar 2025 16:52:21 +0800
Subject: [PATCH] [webui] support escape html (#7190)

Former-commit-id: abb23f767351098a926202ea4edc94d9e9a4681c
---
 src/llamafactory/chat/hf_engine.py           | 12 +++++--
 src/llamafactory/chat/vllm_engine.py         |  5 ++-
 src/llamafactory/webui/chatter.py            | 30 +++++++++++++----
 src/llamafactory/webui/components/chatbot.py | 22 +++++++++++--
 src/llamafactory/webui/locales.py            | 34 ++++++++++++++++++++
 5 files changed, 92 insertions(+), 11 deletions(-)

diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 897017e0..aee6080f 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -126,6 +126,7 @@ class HuggingfaceEngine(BaseEngine):
         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
         length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -145,6 +146,9 @@ class HuggingfaceEngine(BaseEngine):
                 if repetition_penalty is not None
                 else generating_args["repetition_penalty"],
                 length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+                skip_special_tokens=skip_special_tokens
+                if skip_special_tokens is not None
+                else generating_args["skip_special_tokens"],
                 eos_token_id=template.get_stop_token_ids(tokenizer),
                 pad_token_id=tokenizer.pad_token_id,
             )
@@ -241,7 +245,9 @@ class HuggingfaceEngine(BaseEngine):

         response_ids = generate_output[:, prompt_length:]
         response = tokenizer.batch_decode(
-            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+            response_ids,
+            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
+            clean_up_tokenization_spaces=True,
         )
         results = []
         for i in range(len(response)):
@@ -289,7 +295,9 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         streamer = TextIteratorStreamer(
-            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+            tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
         )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 894d3260..7888ea7b 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -139,6 +139,7 @@ class VllmEngine(BaseEngine):
         num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
         repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
         length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -172,7 +173,9 @@ class VllmEngine(BaseEngine):
             stop=stop,
             stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
             max_tokens=max_tokens,
-            skip_special_tokens=self.generating_args["skip_special_tokens"],
+            skip_special_tokens=skip_special_tokens
+            if skip_special_tokens is not None
+            else self.generating_args["skip_special_tokens"],
         )

         if images is not None:  # add image features
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index 354b80ea..944a3d06 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -36,14 +36,21 @@ if is_gradio_available():
     import gradio as gr


-def _format_response(text: str, lang: str, thought_words: Tuple[str, str] = ("<think>", "</think>")) -> str:
+def _escape_html(text: str) -> str:
+    r"""
+    Escapes HTML characters.
+    """
+    return text.replace("<", "&lt;").replace(">", "&gt;")
+
+
+def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
     r"""
     Post-processes the response text.

     Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
     """
     if thought_words[0] not in text:
-        return text
+        return _escape_html(text) if escape_html else text

     text = text.replace(thought_words[0], "")
     result = text.split(thought_words[1], maxsplit=1)
@@ -54,6 +61,9 @@ class WebChatModel(ChatModel):
         summary = ALERTS["info_thought"][lang]
         thought, answer = result

+    if escape_html:
+        thought, answer = _escape_html(thought), _escape_html(answer)
+
     return (
         f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
         f"<div class='thinking-container'>\n{thought}\n</div>\n{answer}"
@@ -154,14 +164,19 @@
         messages: List[Dict[str, str]],
         role: str,
         query: str,
+        escape_html: bool,
     ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
         r"""
         Adds the user input to chatbot.

-        Inputs: infer.chatbot, infer.messages, infer.role, infer.query
-        Output: infer.chatbot, infer.messages
+        Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
+        Output: infer.chatbot, infer.messages, infer.query
         """
-        return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], ""
+        return (
+            chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
+            messages + [{"role": role, "content": query}],
+            "",
+        )

     def stream(
         self,
@@ -176,6 +191,8 @@ class WebChatModel(ChatModel):
         max_new_tokens: int,
         top_p: float,
         temperature: float,
+        skip_special_tokens: bool,
+        escape_html: bool,
     ) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
         r"""
         Generates output text in stream.
@@ -195,6 +212,7 @@ class WebChatModel(ChatModel):
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             temperature=temperature,
+            skip_special_tokens=skip_special_tokens,
         ):
             response += new_text
             if tools:
@@ -209,7 +227,7 @@ class WebChatModel(ChatModel):
                 bot_text = "```json\n" + tool_calls + "\n```"
             else:
                 output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
-                bot_text = _format_response(result, lang, self.engine.template.thought_words)
+                bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)

             chatbot[-1] = {"role": "assistant", "content": bot_text}
             yield chatbot, output_messages
diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py
index 314a711d..51e9691d 100644
--- a/src/llamafactory/webui/components/chatbot.py
+++ b/src/llamafactory/webui/components/chatbot.py
@@ -79,17 +79,33 @@ def create_chat_box(
                 max_new_tokens = gr.Slider(minimum=8, maximum=8192, value=1024, step=1)
                 top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
                 temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
+                skip_special_tokens = gr.Checkbox(value=True)
+                escape_html = gr.Checkbox(value=True)
                 clear_btn = gr.Button()

     tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])

     submit_btn.click(
         engine.chatter.append,
-        [chatbot, messages, role, query],
+        [chatbot, messages, role, query, escape_html],
         [chatbot, messages, query],
     ).then(
         engine.chatter.stream,
-        [chatbot, messages, lang, system, tools, image, video, audio, max_new_tokens, top_p, temperature],
+        [
+            chatbot,
+            messages,
+            lang,
+            system,
+            tools,
+            image,
+            video,
+            audio,
+            max_new_tokens,
+            top_p,
+            temperature,
+            skip_special_tokens,
+            escape_html,
+        ],
         [chatbot, messages],
     )
     clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
@@ -111,6 +127,8 @@ def create_chat_box(
             max_new_tokens=max_new_tokens,
             top_p=top_p,
             temperature=temperature,
+            skip_special_tokens=skip_special_tokens,
+            escape_html=escape_html,
             clear_btn=clear_btn,
         ),
     )
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index ec23e5ab..a0b4c03e 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -2412,6 +2412,40 @@ LOCALES = {
             "label": "温度",
         },
     },
+    "skip_special_tokens": {
+        "en": {
+            "label": "Skip special tokens",
+        },
+        "ru": {
+            "label": "Пропустить специальные токены",
+        },
+        "zh": {
+            "label": "跳过特殊 token",
+        },
"ko": { + "label": "스페셜 토큰을 건너뛰기", + }, + "ja": { + "label": "スペシャルトークンをスキップ", + }, + }, + "escape_html": { + "en": { + "label": "Escape HTML tags", + }, + "ru": { + "label": "Исключить HTML теги", + }, + "zh": { + "label": "转义 HTML 标签", + }, + "ko": { + "label": "HTML 태그 이스케이프", + }, + "ja": { + "label": "HTML タグをエスケープ", + }, + }, "clear_btn": { "en": { "value": "Clear history",