[data] qwen3 fixes (#8109)

hoshi-hiyouga 2025-05-20 02:00:30 +08:00 committed by GitHub
parent f3fd67a9bb
commit b83a38eb98
13 changed files with 197 additions and 160 deletions


@@ -250,7 +250,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 </details>

-> [!NOTE]
+> [!TIP]
 > If you cannot use the latest feature, please pull the latest code and install LLaMA-Factory again.

 ## Supported Models


@@ -237,7 +237,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 </details>

-> [!NOTE]
+> [!TIP]
 > 如果您无法使用最新的功能,请尝试重新拉取代码并再次安装 LLaMA-Factory。

 ## 模型


@@ -50,7 +50,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
 * [Example dataset](alpaca_en_demo.json)

-In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the human prompt, then the human prompt would be `instruction\ninput`. The `output` column represents the model response.
+In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the user prompt, then the user prompt would be `instruction\ninput`. The `output` column represents the model response.
+
+For reasoning models, if the dataset contains chain-of-thought (CoT), the CoT needs to be placed in the model responses, such as `<think>cot</think>output`.

 The `system` column will be used as the system prompt if specified.
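To make the CoT placement above concrete, here is a minimal, hypothetical alpaca-style record that keeps the chain-of-thought inside the model response (field values are illustrative placeholders, not taken from the repository's example datasets):

```python
import json

# Hypothetical SFT record for a reasoning model: the CoT stays in `output`,
# wrapped in the thought tags (here `<think>...</think>`, as in the note above),
# so it is part of the supervised response.
record = {
    "instruction": "Solve the equation 2x + 3 = 11.",
    "input": "",
    "output": "<think>Subtract 3 from both sides to get 2x = 8, then divide by 2.</think>x = 4",
}

with open("my_reasoning_dataset.json", "w", encoding="utf-8") as f:
    json.dump([record], f, ensure_ascii=False, indent=2)
```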
@@ -59,13 +61,13 @@ The `history` column is a list consisting of string tuples representing prompt-r
 ```json
 [
   {
-    "instruction": "human instruction (required)",
-    "input": "human input (optional)",
+    "instruction": "user instruction (required)",
+    "input": "user input (optional)",
     "output": "model response (required)",
     "system": "system prompt (optional)",
     "history": [
-      ["human instruction in the first round (optional)", "model response in the first round (optional)"],
-      ["human instruction in the second round (optional)", "model response in the second round (optional)"]
+      ["user instruction in the first round (optional)", "model response in the first round (optional)"],
+      ["user instruction in the second round (optional)", "model response in the second round (optional)"]
     ]
   }
 ]
@@ -86,6 +88,9 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
 }
 ```
+
+> [!TIP]
+> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add an empty CoT to the data. When `enable_thinking` is `True`, the empty CoT is added to the model response and included in the loss computation; otherwise, it is added to the user prompt and excluded from the loss computation. Please keep the `enable_thinking` parameter consistent during training and inference.

 ### Pre-training Dataset

 - [Example dataset](c4_demo.jsonl)
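To illustrate the tip above, here is a standalone sketch (not library code) of where the empty thought block ends up for the `qwen3` template, whose thought tags are `<think>` and `</think>`; the rendered strings follow the template tests included later in this commit.

```python
# Assume a record whose output has no CoT, e.g. {"instruction": "How are you", "output": "I am fine!"}.
# LLaMA-Factory inserts an empty thought block; enable_thinking decides which side it lands on.
EMPTY_THOUGHT = "<think>\n\n</think>\n\n"

prompt = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"
response = "I am fine!<|im_end|>\n"

enable_thinking = True
if enable_thinking:
    # Empty CoT goes into the response: its tokens are labeled and contribute to the loss.
    response = EMPTY_THOUGHT + response
else:
    # Empty CoT goes into the prompt: its tokens are masked out of the loss.
    prompt = prompt + EMPTY_THOUGHT

print(prompt)
print(response)
```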
@@ -119,8 +124,8 @@ It requires a better response in `chosen` column and a worse response in `reject
 ```json
 [
   {
-    "instruction": "human instruction (required)",
-    "input": "human input (optional)",
+    "instruction": "user instruction (required)",
+    "input": "user input (optional)",
     "chosen": "chosen answer (required)",
     "rejected": "rejected answer (required)"
   }
@@ -174,7 +179,7 @@ Note that the human and observation should appear in odd positions, while gpt an
     "conversations": [
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       },
       {
         "from": "function_call",
@@ -225,7 +230,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
     "conversations": [
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       },
       {
         "from": "gpt",
@@ -233,7 +238,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
       },
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       }
     ],
     "chosen": {
@@ -275,7 +280,7 @@ KTO datasets require a extra `kto_tag` column containing the boolean human feedb
     "conversations": [
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       },
       {
         "from": "gpt",
@@ -314,7 +319,7 @@ The number of images should be identical to the `<image>` tokens in the conversa
     "conversations": [
       {
         "from": "human",
-        "value": "<image>human instruction"
+        "value": "<image>user instruction"
       },
       {
         "from": "gpt",
@@ -355,7 +360,7 @@ The number of videos should be identical to the `<video>` tokens in the conversa
     "conversations": [
       {
         "from": "human",
-        "value": "<video>human instruction"
+        "value": "<video>user instruction"
       },
       {
         "from": "gpt",
@@ -396,7 +401,7 @@ The number of audios should be identical to the `<audio>` tokens in the conversa
     "conversations": [
      {
         "from": "human",
-        "value": "<audio>human instruction"
+        "value": "<audio>user instruction"
      },
       {
         "from": "gpt",
@@ -437,7 +442,7 @@ The openai format is simply a special case of the sharegpt format, where the fir
      },
      {
        "role": "user",
-       "content": "human instruction"
+       "content": "user instruction"
      },
      {
        "role": "assistant",


@@ -49,7 +49,9 @@
 - [样例数据集](alpaca_zh_demo.json)

-在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
+在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为提示词,即提示词为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
+
+对于推理类模型的微调,如果数据集包含思维链,则需要把思维链放在模型回答中,例如 `<think>cot</think>output`

 如果指定,`system` 列对应的内容将被作为系统提示词。
@@ -58,8 +60,8 @@
 ```json
 [
   {
-    "instruction": "人类指令(必填)",
-    "input": "人类输入(选填)",
+    "instruction": "用户指令(必填)",
+    "input": "用户输入(选填)",
     "output": "模型回答(必填)",
     "system": "系统提示词(选填)",
     "history": [
@@ -85,6 +87,9 @@
 }
 ```
+
+> [!TIP]
+> 如果模型本身具备推理能力而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时,空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失。请在训练和推理时保持 `enable_thinking` 参数一致。

 ### 预训练数据集

 - [样例数据集](c4_demo.jsonl)
@@ -118,8 +123,8 @@
 ```json
 [
   {
-    "instruction": "人类指令(必填)",
-    "input": "人类输入(选填)",
+    "instruction": "用户指令(必填)",
+    "input": "用户输入(选填)",
     "chosen": "优质回答(必填)",
     "rejected": "劣质回答(必填)"
   }
@@ -173,7 +178,7 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
     "conversations": [
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       },
       {
         "from": "function_call",
@@ -224,7 +229,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
     "conversations": [
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       },
       {
         "from": "gpt",
@@ -232,7 +237,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
       },
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       }
     ],
     "chosen": {
@@ -274,7 +279,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       },
       {
         "from": "gpt",
@@ -313,7 +318,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "<image>人类指令"
+        "value": "<image><image>用户指令"
       },
       {
         "from": "gpt",
@@ -321,6 +326,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
       }
     ],
     "images": [
+      "图像路径(必填)",
       "图像路径(必填)"
     ]
   }
@@ -354,7 +360,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "<video>人类指令"
+        "value": "<video><video>用户指令"
       },
       {
         "from": "gpt",
@@ -362,6 +368,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
       }
     ],
     "videos": [
+      "视频路径(必填)",
       "视频路径(必填)"
     ]
   }
@@ -395,7 +402,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "<audio>人类指令"
+        "value": "<audio><audio>用户指令"
       },
       {
         "from": "gpt",
@@ -403,6 +410,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
       }
     ],
     "audios": [
+      "音频路径(必填)",
       "音频路径(必填)"
     ]
   }
@@ -437,7 +445,7 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
      },
      {
        "role": "user",
-       "content": "人类指令"
+       "content": "用户指令"
      },
      {
        "role": "assistant",


@@ -49,6 +49,8 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     skip_special_tokens: bool = True,
+    default_system: Optional[str] = None,
+    enable_thinking: bool = True,
     seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
@@ -74,6 +76,8 @@ def vllm_infer(
         cutoff_len=cutoff_len,
         max_samples=max_samples,
         preprocessing_num_workers=16,
+        default_system=default_system,
+        enable_thinking=enable_thinking,
         vllm_config=vllm_config,
         temperature=temperature,
         top_p=top_p,
@@ -127,14 +131,11 @@ def vllm_infer(
         lora_request = None

     # Store all results in these lists
-    all_prompts = []
-    all_preds = []
-    all_labels = []
+    all_prompts, all_preds, all_labels = [], [], []

     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
         vllm_inputs, prompts, labels = [], [], []
         batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
         for j in range(len(batch["input_ids"])):
@@ -176,15 +177,14 @@ def vllm_infer(
         )

         results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
         preds = [result.outputs[0].text for result in results]

         # Accumulate results
         all_prompts.extend(prompts)
         all_preds.extend(preds)
         all_labels.extend(labels)
         gc.collect()

     # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
         for text, pred, label in zip(all_prompts, all_preds, all_labels):


@@ -104,10 +104,7 @@ class HuggingfaceEngine(BaseEngine):
             messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else generating_args["enable_thinking"]
-        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
         prompt_ids, _ = template.mm_plugin.process_token_ids(
             prompt_ids,
             None,


@@ -160,10 +160,7 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
-        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)

         temperature: Optional[float] = input_kwargs.pop("temperature", None)


@@ -124,10 +124,7 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
-        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)

         temperature: Optional[float] = input_kwargs.pop("temperature", None)


@@ -52,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: bool
     mm_plugin: "BasePlugin"

     def encode_oneturn(
@@ -60,7 +61,6 @@ class Template:
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        enable_thinking: bool = False,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -94,7 +94,7 @@ class Template:
         return list(stop_token_ids)

-    def add_thought(self, content: str) -> str:
+    def add_thought(self, content: str = "") -> str:
         r"""Add empty thought to assistant message."""
         return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
@@ -105,7 +105,7 @@ class Template:
     def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
         r"""Get the token ids of thought words."""
-        return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)

     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
@@ -406,26 +406,21 @@ class ReasoningTemplate(Template):
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        enable_thinking: bool = False,
     ) -> tuple[list[int], list[int]]:
         messages = deepcopy(messages)
-        for i in range(len(messages)):
-            if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])

-        encoded_messages = self._encode(tokenizer, messages, system, tools)
-        prompt_ids = []
-        for encoded_ids in encoded_messages[:-1]:
-            prompt_ids += encoded_ids
-
-        if not enable_thinking and (
-            messages[-1]["role"] == Role.ASSISTANT
-            and self.thought_words[0] not in messages[-1]["content"]
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
             and self.thought_words[1] not in messages[-1]["content"]
         ):
-            prompt_ids += self.get_thought_word_ids(tokenizer)
+            if not self.enable_thinking:
+                prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer)
+            else:
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids

-        response_ids = encoded_messages[-1]
         return prompt_ids, response_ids

     @override
@@ -436,15 +431,16 @@ class ReasoningTemplate(Template):
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
-        messages = deepcopy(messages)
         encoded_messages = self._encode(tokenizer, messages, system, tools)
-        for i in range(len(messages) - 1):
+        for i in range(0, len(messages), 2):
             if (
-                messages[i + 1]["role"] == Role.ASSISTANT
-                and self.thought_words[0] not in messages[i + 1]["content"]
+                self.thought_words[0] not in messages[i + 1]["content"]
                 and self.thought_words[1] not in messages[i + 1]["content"]
             ):
-                encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                if not self.enable_thinking:
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]

         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
@@ -467,6 +463,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: bool = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:
@@ -513,6 +510,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
@@ -549,6 +547,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags

     if len(user_slot) > len(user_slot_empty_system):
@@ -558,7 +557,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     else:  # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
         default_system = ""

-    return Template(
+    return template_class(
         format_user=StringFormatter(slots=[user_slot]),
         format_assistant=StringFormatter(slots=[assistant_slot]),
         format_system=StringFormatter(slots=[system_slot]),
@@ -572,6 +571,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
+        enable_thinking=True,
         mm_plugin=get_mm_plugin(name="base"),
     )
@@ -600,6 +600,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

+    if data_args.default_system is not None:
+        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+        template.default_system = data_args.default_system
+
+    template.enable_thinking = data_args.enable_thinking
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
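With these changes, `default_system` and `enable_thinking` are configured through the data arguments and applied to the template itself, rather than passed per generation request. A minimal sketch of that flow, assuming the import paths used by the test suite (`llamafactory.data` and `llamafactory.hparams`):

```python
from transformers import AutoTokenizer

# Import paths are assumptions based on how these symbols are used in the tests below.
from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
data_args = DataArguments(
    template="qwen3",
    default_system="You are a helpful assistant.",  # overrides the template's default system message
    enable_thinking=False,  # empty CoT is appended to the prompt instead of the response
)
template = get_template_and_fix_tokenizer(tokenizer, data_args)

messages = [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
]
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)
print(tokenizer.decode(prompt_ids))    # ends with the empty <think>\n\n</think>\n\n block
print(tokenizer.decode(response_ids))  # "I am fine!<|im_end|>\n"
```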


@@ -115,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={


@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any

 from transformers import GenerationConfig
@@ -62,18 +62,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
     )
-    enable_thinking: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
-    )

     def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
         args = asdict(self)


@@ -15,6 +15,7 @@
 import json
 import os
 from collections.abc import Generator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional

 from transformers.utils import is_torch_npu_available
@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
     )

+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+    old_value = getattr(obj, name, None)
+    setattr(obj, name, value)
+    yield
+    setattr(obj, name, old_value)
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
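A standalone sketch of how the `update_attr` helper above behaves: it temporarily overrides an attribute for the duration of the `with` block and restores the previous value afterwards (falling back to `None` if the attribute did not exist). The class below is only an illustrative stand-in for the chat template object.

```python
from contextlib import contextmanager
from typing import Any


@contextmanager
def update_attr(obj: Any, name: str, value: Any):
    # Same logic as the helper added in this commit.
    old_value = getattr(obj, name, None)
    setattr(obj, name, value)
    yield
    setattr(obj, name, old_value)


class DummyTemplate:  # hypothetical stand-in for self.engine.template
    enable_thinking = True


template = DummyTemplate()
with update_attr(template, "enable_thinking", False):
    print(template.enable_thinking)  # False while streaming the response

print(template.enable_thinking)  # restored to True afterwards
```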
@@ -198,6 +207,7 @@ class WebChatModel(ChatModel):
         Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
         Output: infer.chatbot, infer.messages
         """
+        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
             chatbot.append({"role": "assistant", "content": ""})
             response = ""
             for new_text in self.stream_chat(
@@ -211,7 +221,6 @@ class WebChatModel(ChatModel):
                 top_p=top_p,
                 temperature=temperature,
                 skip_special_tokens=skip_special_tokens,
-                enable_thinking=enable_thinking,
             ):
                 response += new_text
                 if tools:


@@ -126,29 +126,50 @@ def test_encode_multiturn(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
-def test_reasoning_encode_oneturn(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
     prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
-        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        f"{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>\n"
+    answer_str = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        if enable_thinking:
+            answer_str = "<think>\n\n</think>\n\n" + answer_str
+        else:
+            prompt_str = prompt_str + "<think>\n\n</think>\n\n"
+
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))

 @pytest.mark.parametrize("use_fast", [True, False])
-def test_reasoning_encode_multiturn(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
-    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
-    prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-    answer_str_1 = "I am fine!<|im_end|>\n"
-    prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-    answer_str_2 = "很高兴认识你!<|im_end|>\n"
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    encoded_pairs = template.encode_multiturn(tokenizer, messages)
+    prompt_str_1 = f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
+    prompt_str_2 = f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        if enable_thinking:
+            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
+            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
+        else:
+            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
+            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
+
     _check_tokenization(
         tokenizer,
         (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
@@ -193,12 +214,12 @@ def test_get_stop_token_ids():
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_gemma_template(use_fast: bool):
     prompt_str = (
-        "<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
-        "<start_of_turn>model\nI am fine!<end_of_turn>\n"
-        "<start_of_turn>user\n你好<end_of_turn>\n"
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
         "<start_of_turn>model\n"
     )
-    answer_str = "很高兴认识你!<end_of_turn>\n"
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
@@ -206,12 +227,12 @@ def test_gemma_template(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
-        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
-        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
+        f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
+        f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot_id|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
@@ -220,12 +241,12 @@ def test_llama3_template(use_fast: bool):
 )
 def test_llama4_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
-        "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
-        "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
+        f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
+        f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
         "<|header_start|>assistant<|header_end|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot|>"
     _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
@@ -234,12 +255,12 @@ def test_llama4_template(use_fast: bool):
 )
 def test_phi4_template(use_fast: bool):
     prompt_str = (
-        "<|im_start|>user<|im_sep|>How are you<|im_end|>"
-        "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
-        "<|im_start|>user<|im_sep|>你好<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
+        f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
         "<|im_start|>assistant<|im_sep|>"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
@@ -247,34 +268,30 @@ def test_phi4_template(use_fast: bool):
 def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>\n"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)

 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen3_template(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+def test_qwen3_template(use_fast: bool, cot_messages: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
-        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
-    )
-    answer_str = "很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
-
-    prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+    answer_str = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        answer_str = "<think>\n\n</think>\n\n" + answer_str
+
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)

 def test_parse_llama3_template():
@@ -293,6 +310,7 @@ def test_parse_llama3_template():
 def test_parse_qwen_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "Template"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]
@@ -303,6 +321,7 @@ def test_parse_qwen_template():
 def test_parse_qwen3_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "ReasoningTemplate"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]