Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-22 23:00:36 +08:00.
Commit: [data] qwen3 fixes (#8109).
@@ -15,6 +15,7 @@
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from transformers.utils import is_torch_npu_available
|
||||
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
def update_attr(obj: Any, name: str, value: Any) -> Generator[None, None, None]:
    r"""Temporarily set ``obj.name`` to ``value`` for the duration of the ``with`` block.

    On exit the previous value is restored. If the attribute did not exist
    before entering, it is deleted again instead of being left behind as ``None``.

    Args:
        obj: the object whose attribute is patched.
        name: the attribute name.
        value: the temporary value to set.
    """
    missing = object()  # sentinel: distinguishes "attribute absent" from "attribute was None"
    old_value = getattr(obj, name, missing)
    setattr(obj, name, value)
    try:
        yield
    finally:
        # restore even if the body raised, otherwise the patch leaks
        if old_value is missing:
            delattr(obj, name)
        else:
            setattr(obj, name, old_value)
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
|
||||
self.manager = manager
|
||||
@@ -198,35 +207,35 @@ class WebChatModel(ChatModel):
|
||||
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
|
||||
Output: infer.chatbot, infer.messages
|
||||
"""
|
||||
chatbot.append({"role": "assistant", "content": ""})
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
images=[image] if image else None,
|
||||
videos=[video] if video else None,
|
||||
audios=[audio] if audio else None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
enable_thinking=enable_thinking,
|
||||
):
|
||||
response += new_text
|
||||
if tools:
|
||||
result = self.engine.template.extract_tool(response)
|
||||
else:
|
||||
result = response
|
||||
with update_attr(self.engine.template, "enable_thinking", enable_thinking):
|
||||
chatbot.append({"role": "assistant", "content": ""})
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
messages,
|
||||
system,
|
||||
tools,
|
||||
images=[image] if image else None,
|
||||
videos=[video] if video else None,
|
||||
audios=[audio] if audio else None,
|
||||
max_new_tokens=max_new_tokens,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
skip_special_tokens=skip_special_tokens,
|
||||
):
|
||||
response += new_text
|
||||
if tools:
|
||||
result = self.engine.template.extract_tool(response)
|
||||
else:
|
||||
result = response
|
||||
|
||||
if isinstance(result, list):
|
||||
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
|
||||
tool_calls = json.dumps(tool_calls, ensure_ascii=False)
|
||||
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
|
||||
bot_text = "```json\n" + tool_calls + "\n```"
|
||||
else:
|
||||
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||
bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
|
||||
if isinstance(result, list):
|
||||
tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
|
||||
tool_calls = json.dumps(tool_calls, ensure_ascii=False)
|
||||
output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
|
||||
bot_text = "```json\n" + tool_calls + "\n```"
|
||||
else:
|
||||
output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
|
||||
bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
|
||||
|
||||
chatbot[-1] = {"role": "assistant", "content": bot_text}
|
||||
yield chatbot, output_messages
|
||||
chatbot[-1] = {"role": "assistant", "content": bot_text}
|
||||
yield chatbot, output_messages
|
||||
|
||||
Reference in New Issue
Block a user