diff --git a/README.md b/README.md
index 9e76b58e..2289c535 100644
--- a/README.md
+++ b/README.md
@@ -250,7 +250,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
-> [!NOTE]
+> [!TIP]
> If you cannot use the latest feature, please pull the latest code and install LLaMA-Factory again.
## Supported Models
diff --git a/README_zh.md b/README_zh.md
index 90ad5a0d..b4219d9c 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -237,7 +237,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
-> [!NOTE]
+> [!TIP]
> 如果您无法使用最新的功能,请尝试重新拉取代码并再次安装 LLaMA-Factory。
## 模型
diff --git a/data/README.md b/data/README.md
index 5c2e969a..90503351 100644
--- a/data/README.md
+++ b/data/README.md
@@ -50,7 +50,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
* [Example dataset](alpaca_en_demo.json)
-In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the human prompt, then the human prompt would be `instruction\ninput`. The `output` column represents the model response.
+In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the user prompt, then the user prompt would be `instruction\ninput`. The `output` column represents the model response.
+
+For reasoning models, if the dataset contains chain-of-thought (CoT), the CoT needs to be placed in the model responses, such as `cot output`.
The `system` column will be used as the system prompt if specified.
@@ -59,13 +61,13 @@ The `history` column is a list consisting of string tuples representing prompt-r
```json
[
{
- "instruction": "human instruction (required)",
- "input": "human input (optional)",
+ "instruction": "user instruction (required)",
+ "input": "user input (optional)",
"output": "model response (required)",
"system": "system prompt (optional)",
"history": [
- ["human instruction in the first round (optional)", "model response in the first round (optional)"],
- ["human instruction in the second round (optional)", "model response in the second round (optional)"]
+ ["user instruction in the first round (optional)", "model response in the first round (optional)"],
+ ["user instruction in the second round (optional)", "model response in the second round (optional)"]
]
}
]
@@ -86,6 +88,9 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
}
```
+> [!TIP]
+> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add empty CoT to the data. When `enable_thinking` is `True`, the empty CoT will be added to the model responses and loss computation will be considered; otherwise, it will be added to the user prompts and loss computation will be ignored. Please keep the `enable_thinking` parameter consistent during training and inference.
+
### Pre-training Dataset
- [Example dataset](c4_demo.jsonl)
@@ -119,8 +124,8 @@ It requires a better response in `chosen` column and a worse response in `reject
```json
[
{
- "instruction": "human instruction (required)",
- "input": "human input (optional)",
+ "instruction": "user instruction (required)",
+ "input": "user input (optional)",
"chosen": "chosen answer (required)",
"rejected": "rejected answer (required)"
}
@@ -174,7 +179,7 @@ Note that the human and observation should appear in odd positions, while gpt an
"conversations": [
{
"from": "human",
- "value": "human instruction"
+ "value": "user instruction"
},
{
"from": "function_call",
@@ -225,7 +230,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
"conversations": [
{
"from": "human",
- "value": "human instruction"
+ "value": "user instruction"
},
{
"from": "gpt",
@@ -233,7 +238,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
},
{
"from": "human",
- "value": "human instruction"
+ "value": "user instruction"
}
],
"chosen": {
@@ -275,7 +280,7 @@ KTO datasets require a extra `kto_tag` column containing the boolean human feedb
"conversations": [
{
"from": "human",
- "value": "human instruction"
+ "value": "user instruction"
},
{
"from": "gpt",
@@ -314,7 +319,7 @@ The number of images should be identical to the `` tokens in the conversa
"conversations": [
{
"from": "human",
- "value": "human instruction"
+ "value": "user instruction"
},
{
"from": "gpt",
@@ -355,7 +360,7 @@ The number of videos should be identical to the `` tokens in the conversa
"conversations": [
{
"from": "human",
- "value": "human instruction"
+ "value": "user instruction"
},
{
"from": "gpt",
@@ -396,7 +401,7 @@ The number of audios should be identical to the `` tokens in the conversa
"conversations": [
{
"from": "human",
- "value": "human instruction"
+ "value": "user instruction"
},
{
"from": "gpt",
@@ -437,7 +442,7 @@ The openai format is simply a special case of the sharegpt format, where the fir
},
{
"role": "user",
- "content": "human instruction"
+ "content": "user instruction"
},
{
"role": "assistant",
diff --git a/data/README_zh.md b/data/README_zh.md
index e36cbfe6..f26725ca 100644
--- a/data/README_zh.md
+++ b/data/README_zh.md
@@ -49,7 +49,9 @@
- [样例数据集](alpaca_zh_demo.json)
-在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
+在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为提示词,即提示词为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
+
+对于推理类模型的微调,如果数据集包含思维链,则需要把思维链放在模型回答中,例如 `cot output`。
如果指定,`system` 列对应的内容将被作为系统提示词。
@@ -58,8 +60,8 @@
```json
[
{
- "instruction": "人类指令(必填)",
- "input": "人类输入(选填)",
+ "instruction": "用户指令(必填)",
+ "input": "用户输入(选填)",
"output": "模型回答(必填)",
"system": "系统提示词(选填)",
"history": [
@@ -85,6 +87,9 @@
}
```
+> [!TIP]
+> 如果模型本身具备推理能力,而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时,空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失。请在训练和推理时保持 `enable_thinking` 参数一致。
+
### 预训练数据集
- [样例数据集](c4_demo.jsonl)
@@ -118,8 +123,8 @@
```json
[
{
- "instruction": "人类指令(必填)",
- "input": "人类输入(选填)",
+ "instruction": "用户指令(必填)",
+ "input": "用户输入(选填)",
"chosen": "优质回答(必填)",
"rejected": "劣质回答(必填)"
}
@@ -173,7 +178,7 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
"conversations": [
{
"from": "human",
- "value": "人类指令"
+ "value": "用户指令"
},
{
"from": "function_call",
@@ -224,7 +229,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
"conversations": [
{
"from": "human",
- "value": "人类指令"
+ "value": "用户指令"
},
{
"from": "gpt",
@@ -232,7 +237,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
},
{
"from": "human",
- "value": "人类指令"
+ "value": "用户指令"
}
],
"chosen": {
@@ -274,7 +279,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [
{
"from": "human",
- "value": "人类指令"
+ "value": "用户指令"
},
{
"from": "gpt",
@@ -313,7 +318,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [
{
"from": "human",
- "value": "人类指令"
+ "value": "用户指令"
},
{
"from": "gpt",
@@ -321,6 +326,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
}
],
"images": [
+ "图像路径(必填)",
"图像路径(必填)"
]
}
@@ -354,7 +360,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [
{
"from": "human",
- "value": "人类指令"
+ "value": "用户指令"
},
{
"from": "gpt",
@@ -362,6 +368,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
}
],
"videos": [
+ "视频路径(必填)",
"视频路径(必填)"
]
}
@@ -395,7 +402,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
"conversations": [
{
"from": "human",
- "value": "人类指令"
+ "value": "用户指令"
},
{
"from": "gpt",
@@ -403,6 +410,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
}
],
"audios": [
+ "音频路径(必填)",
"音频路径(必填)"
]
}
@@ -437,7 +445,7 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
},
{
"role": "user",
- "content": "人类指令"
+ "content": "用户指令"
},
{
"role": "assistant",
diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py
index b91634bb..1a080ad5 100644
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -49,6 +49,8 @@ def vllm_infer(
max_new_tokens: int = 1024,
repetition_penalty: float = 1.0,
skip_special_tokens: bool = True,
+ default_system: Optional[str] = None,
+ enable_thinking: bool = True,
seed: Optional[int] = None,
pipeline_parallel_size: int = 1,
image_max_pixels: int = 768 * 768,
@@ -74,6 +76,8 @@ def vllm_infer(
cutoff_len=cutoff_len,
max_samples=max_samples,
preprocessing_num_workers=16,
+ default_system=default_system,
+ enable_thinking=enable_thinking,
vllm_config=vllm_config,
temperature=temperature,
top_p=top_p,
@@ -127,14 +131,11 @@ def vllm_infer(
lora_request = None
# Store all results in these lists
- all_prompts = []
- all_preds = []
- all_labels = []
+ all_prompts, all_preds, all_labels = [], [], []
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
vllm_inputs, prompts, labels = [], [], []
-
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
for j in range(len(batch["input_ids"])):
@@ -176,15 +177,14 @@ def vllm_infer(
)
results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
-
preds = [result.outputs[0].text for result in results]
# Accumulate results
all_prompts.extend(prompts)
all_preds.extend(preds)
all_labels.extend(labels)
-
gc.collect()
+
# Write all results at once outside the loop
with open(save_name, "w", encoding="utf-8") as f:
for text, pred, label in zip(all_prompts, all_preds, all_labels):
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 5ed47886..adaaaa87 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -104,10 +104,7 @@ class HuggingfaceEngine(BaseEngine):
messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
- system = system or generating_args["default_system"]
- enable_thinking = input_kwargs.pop("enable_thinking", None)
- enable_thinking = enable_thinking if enable_thinking is not None else generating_args["enable_thinking"]
- prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking)
+ prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
prompt_ids, _ = template.mm_plugin.process_token_ids(
prompt_ids,
None,
diff --git a/src/llamafactory/chat/sglang_engine.py b/src/llamafactory/chat/sglang_engine.py
index 99ca04ae..b1d2ead3 100644
--- a/src/llamafactory/chat/sglang_engine.py
+++ b/src/llamafactory/chat/sglang_engine.py
@@ -160,10 +160,7 @@ class SGLangEngine(BaseEngine):
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
- system = system or self.generating_args["default_system"]
- enable_thinking = input_kwargs.pop("enable_thinking", None)
- enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
- prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 9110ae05..b3370527 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -124,10 +124,7 @@ class VllmEngine(BaseEngine):
messages, images or [], videos or [], audios or [], self.processor
)
paired_messages = messages + [{"role": "assistant", "content": ""}]
- system = system or self.generating_args["default_system"]
- enable_thinking = input_kwargs.pop("enable_thinking", None)
- enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
- prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+ prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
prompt_length = len(prompt_ids)
temperature: Optional[float] = input_kwargs.pop("temperature", None)
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b05c3b86..aa422cab 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -52,6 +52,7 @@ class Template:
efficient_eos: bool
replace_eos: bool
replace_jinja_template: bool
+ enable_thinking: bool
mm_plugin: "BasePlugin"
def encode_oneturn(
@@ -60,7 +61,6 @@ class Template:
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
- enable_thinking: bool = False,
) -> tuple[list[int], list[int]]:
r"""Return a single pair of token ids representing prompt and response respectively."""
encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -94,7 +94,7 @@ class Template:
return list(stop_token_ids)
- def add_thought(self, content: str) -> str:
+ def add_thought(self, content: str = "") -> str:
r"""Add empty thought to assistant message."""
return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
@@ -105,7 +105,7 @@ class Template:
def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
r"""Get the token ids of thought words."""
- return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
+ return tokenizer.encode(self.add_thought(), add_special_tokens=False)
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
r"""Convert elements to token ids."""
@@ -406,26 +406,21 @@ class ReasoningTemplate(Template):
messages: list[dict[str, str]],
system: Optional[str] = None,
tools: Optional[str] = None,
- enable_thinking: bool = False,
) -> tuple[list[int], list[int]]:
messages = deepcopy(messages)
- for i in range(len(messages)):
- if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
- messages[i]["content"] = self.remove_thought(messages[i]["content"])
+ for i in range(1, len(messages) - 2, 2):
+ messages[i]["content"] = self.remove_thought(messages[i]["content"])
- encoded_messages = self._encode(tokenizer, messages, system, tools)
- prompt_ids = []
- for encoded_ids in encoded_messages[:-1]:
- prompt_ids += encoded_ids
-
- if not enable_thinking and (
- messages[-1]["role"] == Role.ASSISTANT
- and self.thought_words[0] not in messages[-1]["content"]
+ prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+ if (
+ self.thought_words[0] not in messages[-1]["content"]
and self.thought_words[1] not in messages[-1]["content"]
):
- prompt_ids += self.get_thought_word_ids(tokenizer)
+ if not self.enable_thinking:
+ prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer)
+ else:
+ response_ids = self.get_thought_word_ids(tokenizer) + response_ids
- response_ids = encoded_messages[-1]
return prompt_ids, response_ids
@override
@@ -436,15 +431,16 @@ class ReasoningTemplate(Template):
system: Optional[str] = None,
tools: Optional[str] = None,
) -> list[tuple[list[int], list[int]]]:
- messages = deepcopy(messages)
encoded_messages = self._encode(tokenizer, messages, system, tools)
- for i in range(len(messages) - 1):
+ for i in range(0, len(messages), 2):
if (
- messages[i + 1]["role"] == Role.ASSISTANT
- and self.thought_words[0] not in messages[i + 1]["content"]
+ self.thought_words[0] not in messages[i + 1]["content"]
and self.thought_words[1] not in messages[i + 1]["content"]
):
- encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+ if not self.enable_thinking:
+ encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+ else:
+ encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
@@ -467,6 +463,7 @@ def register_template(
efficient_eos: bool = False,
replace_eos: bool = False,
replace_jinja_template: bool = False,
+ enable_thinking: bool = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: type["Template"] = Template,
) -> None:
@@ -513,6 +510,7 @@ def register_template(
efficient_eos=efficient_eos,
replace_eos=replace_eos,
replace_jinja_template=replace_jinja_template,
+ enable_thinking=enable_thinking,
mm_plugin=mm_plugin,
)
@@ -549,6 +547,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+ template_class = ReasoningTemplate if "" in assistant_slot else Template
assistant_slot = assistant_slot.replace("", "").replace(" ", "").lstrip("\n") # remove thought tags
if len(user_slot) > len(user_slot_empty_system):
@@ -558,7 +557,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
else: # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
default_system = ""
- return Template(
+ return template_class(
format_user=StringFormatter(slots=[user_slot]),
format_assistant=StringFormatter(slots=[assistant_slot]),
format_system=StringFormatter(slots=[system_slot]),
@@ -572,6 +571,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
efficient_eos=False,
replace_eos=False,
replace_jinja_template=False,
+ enable_thinking=True,
mm_plugin=get_mm_plugin(name="base"),
)
@@ -600,6 +600,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
+ if data_args.default_system is not None:
+ logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+ template.default_system = data_args.default_system
+
+ template.enable_thinking = data_args.enable_thinking
template.fix_special_tokens(tokenizer)
template.fix_jinja_template(tokenizer)
return template
diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py
index 60d3036e..588b8c5c 100644
--- a/src/llamafactory/hparams/data_args.py
+++ b/src/llamafactory/hparams/data_args.py
@@ -115,6 +115,14 @@ class DataArguments:
default=None,
metadata={"help": "Tool format to use for constructing function calling examples."},
)
+ default_system: Optional[str] = field(
+ default=None,
+ metadata={"help": "Override the default system message in the template."},
+ )
+ enable_thinking: bool = field(
+ default=True,
+ metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+ )
tokenized_path: Optional[str] = field(
default=None,
metadata={
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index ac377543..7eacb147 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any
from transformers import GenerationConfig
@@ -62,18 +62,10 @@ class GeneratingArguments:
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
- default_system: Optional[str] = field(
- default=None,
- metadata={"help": "Default system message to use in chat completion."},
- )
skip_special_tokens: bool = field(
default=True,
metadata={"help": "Whether or not to remove special tokens in the decoding."},
)
- enable_thinking: bool = field(
- default=True,
- metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
- )
def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
args = asdict(self)
diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py
index 9f5733b3..9b1bcd67 100644
--- a/src/llamafactory/webui/chatter.py
+++ b/src/llamafactory/webui/chatter.py
@@ -15,6 +15,7 @@
import json
import os
from collections.abc import Generator
+from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
from transformers.utils import is_torch_npu_available
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
)
+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+ old_value = getattr(obj, name, None)
+ setattr(obj, name, value)
+ yield
+ setattr(obj, name, old_value)
+
+
class WebChatModel(ChatModel):
def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
self.manager = manager
@@ -198,35 +207,35 @@ class WebChatModel(ChatModel):
Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
Output: infer.chatbot, infer.messages
"""
- chatbot.append({"role": "assistant", "content": ""})
- response = ""
- for new_text in self.stream_chat(
- messages,
- system,
- tools,
- images=[image] if image else None,
- videos=[video] if video else None,
- audios=[audio] if audio else None,
- max_new_tokens=max_new_tokens,
- top_p=top_p,
- temperature=temperature,
- skip_special_tokens=skip_special_tokens,
- enable_thinking=enable_thinking,
- ):
- response += new_text
- if tools:
- result = self.engine.template.extract_tool(response)
- else:
- result = response
+ with update_attr(self.engine.template, "enable_thinking", enable_thinking):
+ chatbot.append({"role": "assistant", "content": ""})
+ response = ""
+ for new_text in self.stream_chat(
+ messages,
+ system,
+ tools,
+ images=[image] if image else None,
+ videos=[video] if video else None,
+ audios=[audio] if audio else None,
+ max_new_tokens=max_new_tokens,
+ top_p=top_p,
+ temperature=temperature,
+ skip_special_tokens=skip_special_tokens,
+ ):
+ response += new_text
+ if tools:
+ result = self.engine.template.extract_tool(response)
+ else:
+ result = response
- if isinstance(result, list):
- tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
- tool_calls = json.dumps(tool_calls, ensure_ascii=False)
- output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
- bot_text = "```json\n" + tool_calls + "\n```"
- else:
- output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
- bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
+ if isinstance(result, list):
+ tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
+ tool_calls = json.dumps(tool_calls, ensure_ascii=False)
+ output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
+ bot_text = "```json\n" + tool_calls + "\n```"
+ else:
+ output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
+ bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
- chatbot[-1] = {"role": "assistant", "content": bot_text}
- yield chatbot, output_messages
+ chatbot[-1] = {"role": "assistant", "content": bot_text}
+ yield chatbot, output_messages
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index e74e8b45..eef52efe 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -126,29 +126,50 @@ def test_encode_multiturn(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
-def test_reasoning_encode_oneturn(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+ messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
- template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
- prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+ data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+ template = get_template_and_fix_tokenizer(tokenizer, data_args)
+ prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
prompt_str = (
- "<|im_start|>user\nHow are you<|im_end|>\n"
- "<|im_start|>assistant\nI am fine!<|im_end|>\n"
- "<|im_start|>user\n你好<|im_end|>\n"
- "<|im_start|>assistant\n\n\n \n\n"
+ f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+ f"{MESSAGES[1]['content']}<|im_end|>\n"
+ f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
)
- answer_str = "很高兴认识你!<|im_end|>\n"
+ answer_str = f"{messages[3]['content']}<|im_end|>\n"
+ if not cot_messages:
+ if enable_thinking:
+ answer_str = "\n\n \n\n" + answer_str
+ else:
+ prompt_str = prompt_str + "\n\n \n\n"
+
_check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
@pytest.mark.parametrize("use_fast", [True, False])
-def test_reasoning_encode_multiturn(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+ messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
- template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
- encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
- prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n\n\n \n\n"
- answer_str_1 = "I am fine!<|im_end|>\n"
- prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n\n\n \n\n"
- answer_str_2 = "很高兴认识你!<|im_end|>\n"
+ data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+ template = get_template_and_fix_tokenizer(tokenizer, data_args)
+ encoded_pairs = template.encode_multiturn(tokenizer, messages)
+ prompt_str_1 = f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+ answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
+ prompt_str_2 = f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+ answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
+ if not cot_messages:
+ if enable_thinking:
+ answer_str_1 = "\n\n \n\n" + answer_str_1
+ answer_str_2 = "\n\n \n\n" + answer_str_2
+ else:
+ prompt_str_1 = prompt_str_1 + "\n\n \n\n"
+ prompt_str_2 = prompt_str_2 + "\n\n \n\n"
+
_check_tokenization(
tokenizer,
(encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
@@ -193,12 +214,12 @@ def test_get_stop_token_ids():
@pytest.mark.parametrize("use_fast", [True, False])
def test_gemma_template(use_fast: bool):
prompt_str = (
- "user\nHow are you\n"
- "model\nI am fine!\n"
- "user\n你好\n"
+ f"user\n{MESSAGES[0]['content']}\n"
+ f"model\n{MESSAGES[1]['content']}\n"
+ f"user\n{MESSAGES[2]['content']}\n"
"model\n"
)
- answer_str = "很高兴认识你!\n"
+ answer_str = f"{MESSAGES[3]['content']}\n"
_check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@@ -206,12 +227,12 @@ def test_gemma_template(use_fast: bool):
@pytest.mark.parametrize("use_fast", [True, False])
def test_llama3_template(use_fast: bool):
prompt_str = (
- "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
- "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
- "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+ f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
+ f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
+ f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
- answer_str = "很高兴认识你!<|eot_id|>"
+ answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
_check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@@ -220,12 +241,12 @@ def test_llama3_template(use_fast: bool):
)
def test_llama4_template(use_fast: bool):
prompt_str = (
- "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
- "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
- "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+ f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
+ f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
+ f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
"<|header_start|>assistant<|header_end|>\n\n"
)
- answer_str = "很高兴认识你!<|eot|>"
+ answer_str = f"{MESSAGES[3]['content']}<|eot|>"
_check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
@@ -234,12 +255,12 @@ def test_llama4_template(use_fast: bool):
)
def test_phi4_template(use_fast: bool):
prompt_str = (
- "<|im_start|>user<|im_sep|>How are you<|im_end|>"
- "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
- "<|im_start|>user<|im_sep|>你好<|im_end|>"
+ f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
+ f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
+ f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
"<|im_start|>assistant<|im_sep|>"
)
- answer_str = "很高兴认识你!<|im_end|>"
+ answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
_check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@@ -247,34 +268,30 @@ def test_phi4_template(use_fast: bool):
def test_qwen2_5_template(use_fast: bool):
prompt_str = (
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
- "<|im_start|>user\nHow are you<|im_end|>\n"
- "<|im_start|>assistant\nI am fine!<|im_end|>\n"
- "<|im_start|>user\n你好<|im_end|>\n"
+ f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
+ f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+ f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
"<|im_start|>assistant\n"
)
- answer_str = "很高兴认识你!<|im_end|>\n"
+ answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
_check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
@pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen3_template(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+def test_qwen3_template(use_fast: bool, cot_messages: bool):
+ messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
prompt_str = (
- "<|im_start|>user\nHow are you<|im_end|>\n"
- "<|im_start|>assistant\nI am fine!<|im_end|>\n"
- "<|im_start|>user\n你好<|im_end|>\n"
- "<|im_start|>assistant\n\n\n \n\n"
- )
- answer_str = "很高兴认识你!<|im_end|>\n"
- _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
-
- prompt_str = (
- "<|im_start|>user\nHow are you<|im_end|>\n"
- "<|im_start|>assistant\nI am fine!<|im_end|>\n"
- "<|im_start|>user\n你好<|im_end|>\n"
+ f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n"
+ f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+ f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n"
"<|im_start|>assistant\n"
)
- answer_str = "\n模型思考内容\n \n\n很高兴认识你!<|im_end|>\n"
- _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+ answer_str = f"{messages[3]['content']}<|im_end|>\n"
+ if not cot_messages:
+ answer_str = "\n\n \n\n" + answer_str
+
+ _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
def test_parse_llama3_template():
@@ -293,6 +310,7 @@ def test_parse_llama3_template():
def test_parse_qwen_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
template = parse_template(tokenizer)
+ assert template.__class__.__name__ == "Template"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
@@ -303,6 +321,7 @@ def test_parse_qwen_template():
def test_parse_qwen3_template():
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
template = parse_template(tokenizer)
+ assert template.__class__.__name__ == "ReasoningTemplate"
assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]