diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 897017e0..aee6080f 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -126,6 +126,7 @@ class HuggingfaceEngine(BaseEngine):
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -145,6 +146,9 @@ class HuggingfaceEngine(BaseEngine):
if repetition_penalty is not None
else generating_args["repetition_penalty"],
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+ skip_special_tokens=skip_special_tokens
+ if skip_special_tokens is not None
+ else generating_args["skip_special_tokens"],
eos_token_id=template.get_stop_token_ids(tokenizer),
pad_token_id=tokenizer.pad_token_id,
)
@@ -241,7 +245,9 @@ class HuggingfaceEngine(BaseEngine):
response_ids = generate_output[:, prompt_length:]
response = tokenizer.batch_decode(
- response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+ response_ids,
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
+ clean_up_tokenization_spaces=True,
)
results = []
for i in range(len(response)):
@@ -289,7 +295,9 @@ class HuggingfaceEngine(BaseEngine):
input_kwargs,
)
streamer = TextIteratorStreamer(
- tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+ tokenizer,
+ skip_prompt=True,
+ skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
)
gen_kwargs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
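Reviewer note: the new argument follows the same pop-with-fallback pattern as the neighboring sampling parameters, and because `skip_special_tokens` is not a first-class `transformers.GenerationConfig` field, the decode path reads it back with `getattr(..., True)`. A minimal sketch of the resolution logic these hunks implement (dict contents are illustrative):

```python
from typing import Any, Dict, Optional

def resolve_skip_special_tokens(input_kwargs: Dict[str, Any], generating_args: Dict[str, Any]) -> bool:
    # The per-request value wins; `None` means "not specified", so an
    # explicit `False` from the caller is still honored.
    value: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
    return value if value is not None else generating_args["skip_special_tokens"]

# An explicit per-request False overrides the configured default of True.
print(resolve_skip_special_tokens({"skip_special_tokens": False}, {"skip_special_tokens": True}))  # False
print(resolve_skip_special_tokens({}, {"skip_special_tokens": True}))  # True
```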
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 894d3260..7888ea7b 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -139,6 +139,7 @@ class VllmEngine(BaseEngine):
num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+ skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
max_length: Optional[int] = input_kwargs.pop("max_length", None)
max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
@@ -172,7 +173,9 @@ class VllmEngine(BaseEngine):
stop=stop,
stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
max_tokens=max_tokens,
- skip_special_tokens=self.generating_args["skip_special_tokens"],
+ skip_special_tokens=skip_special_tokens
+ if skip_special_tokens is not None
+ else self.generating_args["skip_special_tokens"],
)
if images is not None: # add image features
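On the vLLM path no `getattr` fallback is needed, since `SamplingParams` exposes `skip_special_tokens` directly (defaulting to `True`) and applies it at detokenization time. A standalone illustration of the flag outside the engine; the model name is an arbitrary example, not mandated by this change:

```python
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # example model
params = SamplingParams(max_tokens=32, skip_special_tokens=False)  # keep tokens like <|im_end|> visible
output = llm.generate(["Hello, my name is"], params)[0]
print(output.outputs[0].text)
```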
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index 354b80ea..944a3d06 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -36,14 +36,21 @@ if is_gradio_available():
import gradio as gr
-def _format_response(text: str, lang: str, thought_words: Tuple[str, str] = ("<think>", "</think>")) -> str:
+def _escape_html(text: str) -> str:
+ r"""
+ Escapes HTML characters.
+ """
+    return text.replace("<", "&lt;").replace(">", "&gt;")
+
+
+def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
r"""
Post-processes the response text.
Based on: https://huggingface.co/spaces/Lyte/DeepSeek-R1-Distill-Qwen-1.5B-Demo-GGUF/blob/main/app.py
"""
if thought_words[0] not in text:
- return text
+ return _escape_html(text) if escape_html else text
text = text.replace(thought_words[0], "")
result = text.split(thought_words[1], maxsplit=1)
@@ -54,6 +61,9 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: Tuple[str, str]) -> str:
         summary = ALERTS["info_thought"][lang]
         thought, answer = result
 
+    if escape_html:
+        thought, answer = _escape_html(thought), _escape_html(answer)
+
     return (
         f"<details open><summary class='thinking-summary'><span>{summary}</span></summary>\n\n"
         f"<div class='thinking-container'>\n{thought}\n</div>\n</details>{answer}"
@@ -154,14 +164,19 @@ class WebChatModel(ChatModel):
messages: List[Dict[str, str]],
role: str,
query: str,
+ escape_html: bool,
) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], str]:
r"""
Adds the user input to chatbot.
- Inputs: infer.chatbot, infer.messages, infer.role, infer.query
- Output: infer.chatbot, infer.messages
+ Inputs: infer.chatbot, infer.messages, infer.role, infer.query, infer.escape_html
+ Output: infer.chatbot, infer.messages, infer.query
"""
- return chatbot + [{"role": "user", "content": query}], messages + [{"role": role, "content": query}], ""
+ return (
+ chatbot + [{"role": "user", "content": _escape_html(query) if escape_html else query}],
+ messages + [{"role": role, "content": query}],
+ "",
+ )
def stream(
self,
@@ -176,6 +191,8 @@ class WebChatModel(ChatModel):
max_new_tokens: int,
top_p: float,
temperature: float,
+ skip_special_tokens: bool,
+ escape_html: bool,
) -> Generator[Tuple[List[Dict[str, str]], List[Dict[str, str]]], None, None]:
r"""
Generates output text in stream.
@@ -195,6 +212,7 @@ class WebChatModel(ChatModel):
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
+ skip_special_tokens=skip_special_tokens,
):
response += new_text
if tools:
@@ -209,7 +227,7 @@ class WebChatModel(ChatModel):
bot_text = "```json\n" + tool_calls + "\n```"
else:
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
- bot_text = _format_response(result, lang, self.engine.template.thought_words)
+ bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
chatbot[-1] = {"role": "assistant", "content": bot_text}
yield chatbot, output_messages
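The helper is deliberately minimal: rewriting only the angle brackets stops the HTML-rendering Gradio chatbot from swallowing raw model output such as `<think>` tags, without rewriting ampersands and quotes the way `html.escape` would. A quick self-contained check of the helper as patched:

```python
def _escape_html(text: str) -> str:
    r"""Escapes HTML characters."""
    return text.replace("<", "&lt;").replace(">", "&gt;")

assert _escape_html("<think>reasoning</think>") == "&lt;think&gt;reasoning&lt;/think&gt;"
assert _escape_html("a < b") == "a &lt; b"  # bare comparisons are escaped too
print("ok")
```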
diff --git a/src/llamafactory/webui/components/chatbot.py b/src/llamafactory/webui/components/chatbot.py
index 314a711d..51e9691d 100644
--- a/src/llamafactory/webui/components/chatbot.py
+++ b/src/llamafactory/webui/components/chatbot.py
@@ -79,17 +79,33 @@ def create_chat_box(
max_new_tokens = gr.Slider(minimum=8, maximum=8192, value=1024, step=1)
top_p = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.01)
temperature = gr.Slider(minimum=0.01, maximum=1.5, value=0.95, step=0.01)
+ skip_special_tokens = gr.Checkbox(value=True)
+ escape_html = gr.Checkbox(value=True)
clear_btn = gr.Button()
tools.input(check_json_schema, inputs=[tools, engine.manager.get_elem_by_id("top.lang")])
submit_btn.click(
engine.chatter.append,
- [chatbot, messages, role, query],
+ [chatbot, messages, role, query, escape_html],
[chatbot, messages, query],
).then(
engine.chatter.stream,
- [chatbot, messages, lang, system, tools, image, video, audio, max_new_tokens, top_p, temperature],
+ [
+                chatbot,
+                messages,
+                lang,
+                system,
+                tools,
+                image,
+                video,
+                audio,
+                max_new_tokens,
+                top_p,
+                temperature,
+                skip_special_tokens,
+                escape_html,
+ ],
[chatbot, messages],
)
clear_btn.click(lambda: ([], []), outputs=[chatbot, messages])
@@ -111,6 +127,8 @@ def create_chat_box(
max_new_tokens=max_new_tokens,
top_p=top_p,
temperature=temperature,
+ skip_special_tokens=skip_special_tokens,
+ escape_html=escape_html,
clear_btn=clear_btn,
),
)
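Gradio passes component values to event handlers positionally, so the expanded `inputs` list must match the parameter order of `WebChatModel.stream` exactly; both new checkboxes are appended after `temperature` in the list and in the signature. A reduced sketch of this wiring pattern (handler and labels are illustrative, not the project's code):

```python
import gradio as gr

def respond(message: str, escape_html: bool) -> str:
    # Positional order mirrors the `inputs` list below.
    return message.replace("<", "&lt;").replace(">", "&gt;") if escape_html else message

with gr.Blocks() as demo:
    query = gr.Textbox(label="Query")
    escape_html = gr.Checkbox(value=True, label="Escape HTML tags")
    answer = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Submit")
    submit_btn.click(respond, inputs=[query, escape_html], outputs=[answer])

# demo.launch()
```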
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index ec23e5ab..a0b4c03e 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -2412,6 +2412,40 @@ LOCALES = {
"label": "温度",
},
},
+ "skip_special_tokens": {
+ "en": {
+ "label": "Skip special tokens",
+ },
+ "ru": {
+ "label": "Пропустить специальные токены",
+ },
+ "zh": {
+ "label": "跳过特殊 token",
+ },
+ "ko": {
+ "label": "스페셜 토큰을 건너뛰기",
+ },
+ "ja": {
+ "label": "スペシャルトークンをスキップ",
+ },
+ },
+ "escape_html": {
+ "en": {
+ "label": "Escape HTML tags",
+ },
+ "ru": {
+ "label": "Исключить HTML теги",
+ },
+ "zh": {
+ "label": "转义 HTML 标签",
+ },
+ "ko": {
+ "label": "HTML 태그 이스케이프",
+ },
+ "ja": {
+ "label": "HTML タグをエスケープ",
+ },
+ },
"clear_btn": {
"en": {
"value": "Clear history",