[data] qwen3 fixes (#8109)

commit b83a38eb98 (parent f3fd67a9bb)
@@ -250,7 +250,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 </details>
 
-> [!NOTE]
+> [!TIP]
 > If you cannot use the latest feature, please pull the latest code and install LLaMA-Factory again.
 
 ## Supported Models

@@ -237,7 +237,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 
 </details>
 
-> [!NOTE]
+> [!TIP]
 > 如果您无法使用最新的功能,请尝试重新拉取代码并再次安装 LLaMA-Factory。
 
 ## 模型
@@ -50,7 +50,9 @@ Currently we support datasets in **alpaca** and **sharegpt** format.
 
 * [Example dataset](alpaca_en_demo.json)
 
-In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the human prompt, then the human prompt would be `instruction\ninput`. The `output` column represents the model response.
+In supervised fine-tuning, the `instruction` column will be concatenated with the `input` column and used as the user prompt, then the user prompt would be `instruction\ninput`. The `output` column represents the model response.
 
+For reasoning models, if the dataset contains chain-of-thought (CoT), the CoT needs to be placed in the model responses, such as `<think>cot</think>output`.
+
 The `system` column will be used as the system prompt if specified.
 

@@ -59,13 +61,13 @@ The `history` column is a list consisting of string tuples representing prompt-r
 ```json
 [
   {
-    "instruction": "human instruction (required)",
-    "input": "human input (optional)",
+    "instruction": "user instruction (required)",
+    "input": "user input (optional)",
     "output": "model response (required)",
     "system": "system prompt (optional)",
     "history": [
-      ["human instruction in the first round (optional)", "model response in the first round (optional)"],
-      ["human instruction in the second round (optional)", "model response in the second round (optional)"]
+      ["user instruction in the first round (optional)", "model response in the first round (optional)"],
+      ["user instruction in the second round (optional)", "model response in the second round (optional)"]
     ]
   }
 ]

@@ -86,6 +88,9 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh
 }
 ```
 
+> [!TIP]
+> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add empty CoT to the data. When `enable_thinking` is `True`, the empty CoT will be added to the model responses and loss computation will be considered; otherwise, it will be added to the user prompts and loss computation will be ignored. Please keep the `enable_thinking` parameter consistent during training and inference.
+
 ### Pre-training Dataset
 
 - [Example dataset](c4_demo.jsonl)
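The new TIP only covers datasets that lack CoT. When the dataset does contain CoT, it stays inside the `output` column as described above. A minimal sketch of such an alpaca-format record (the field values are invented placeholders, not part of the commit):

```python
# Hedged sketch: one alpaca-format SFT record for a reasoning model, with the
# chain-of-thought kept inside the model response as the README above requires.
import json

record = {
    "instruction": "Solve 12 * 7 and explain briefly.",       # user instruction (required)
    "input": "",                                               # user input (optional)
    "output": "<think>12 * 7 = 84</think>The answer is 84.",   # CoT goes in the response
    "system": "You are a careful math assistant.",             # system prompt (optional)
}

with open("my_reasoning_demo.json", "w", encoding="utf-8") as f:
    json.dump([record], f, ensure_ascii=False, indent=2)
```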
@@ -119,8 +124,8 @@ It requires a better response in `chosen` column and a worse response in `reject
 ```json
 [
   {
-    "instruction": "human instruction (required)",
-    "input": "human input (optional)",
+    "instruction": "user instruction (required)",
+    "input": "user input (optional)",
     "chosen": "chosen answer (required)",
     "rejected": "rejected answer (required)"
   }

@@ -174,7 +179,7 @@ Note that the human and observation should appear in odd positions, while gpt an
     "conversations": [
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       },
       {
         "from": "function_call",

@@ -225,7 +230,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
     "conversations": [
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       },
       {
         "from": "gpt",

@@ -233,7 +238,7 @@ Preference datasets in sharegpt format also require a better message in `chosen`
       },
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       }
     ],
     "chosen": {

@@ -275,7 +280,7 @@ KTO datasets require a extra `kto_tag` column containing the boolean human feedb
     "conversations": [
       {
         "from": "human",
-        "value": "human instruction"
+        "value": "user instruction"
       },
       {
         "from": "gpt",

@@ -314,7 +319,7 @@ The number of images should be identical to the `<image>` tokens in the conversa
     "conversations": [
       {
         "from": "human",
-        "value": "<image>human instruction"
+        "value": "<image>user instruction"
       },
       {
         "from": "gpt",

@@ -355,7 +360,7 @@ The number of videos should be identical to the `<video>` tokens in the conversa
     "conversations": [
       {
         "from": "human",
-        "value": "<video>human instruction"
+        "value": "<video>user instruction"
       },
       {
         "from": "gpt",

@@ -396,7 +401,7 @@ The number of audios should be identical to the `<audio>` tokens in the conversa
     "conversations": [
       {
         "from": "human",
-        "value": "<audio>human instruction"
+        "value": "<audio>user instruction"
       },
       {
         "from": "gpt",

@@ -437,7 +442,7 @@ The openai format is simply a special case of the sharegpt format, where the fir
       },
       {
         "role": "user",
-        "content": "human instruction"
+        "content": "user instruction"
       },
       {
         "role": "assistant",
@@ -49,7 +49,9 @@
 
 - [样例数据集](alpaca_zh_demo.json)
 
-在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为人类指令,即人类指令为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
+在指令监督微调时,`instruction` 列对应的内容会与 `input` 列对应的内容拼接后作为提示词,即提示词为 `instruction\ninput`。而 `output` 列对应的内容为模型回答。
 
+对于推理类模型的微调,如果数据集包含思维链,则需要把思维链放在模型回答中,例如 `<think>cot</think>output`。
+
 如果指定,`system` 列对应的内容将被作为系统提示词。
 

@@ -58,8 +60,8 @@
 ```json
 [
   {
-    "instruction": "人类指令(必填)",
-    "input": "人类输入(选填)",
+    "instruction": "用户指令(必填)",
+    "input": "用户输入(选填)",
     "output": "模型回答(必填)",
     "system": "系统提示词(选填)",
     "history": [

@@ -85,6 +87,9 @@
 }
 ```
 
+> [!TIP]
+> 如果模型本身具备推理能力,而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时,空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失。请在训练和推理时保持 `enable_thinking` 参数一致。
+
 ### 预训练数据集
 
 - [样例数据集](c4_demo.jsonl)

@@ -118,8 +123,8 @@
 ```json
 [
   {
-    "instruction": "人类指令(必填)",
-    "input": "人类输入(选填)",
+    "instruction": "用户指令(必填)",
+    "input": "用户输入(选填)",
     "chosen": "优质回答(必填)",
     "rejected": "劣质回答(必填)"
   }

@@ -173,7 +178,7 @@ KTO 数据集需要提供额外的 `kto_tag` 列。详情请参阅 [sharegpt](#s
     "conversations": [
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       },
       {
         "from": "function_call",

@@ -224,7 +229,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
     "conversations": [
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       },
       {
         "from": "gpt",

@@ -232,7 +237,7 @@ Sharegpt 格式的偏好数据集同样需要在 `chosen` 列中提供更优的
       },
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       }
     ],
     "chosen": {

@@ -274,7 +279,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "人类指令"
+        "value": "用户指令"
       },
       {
         "from": "gpt",

@@ -313,7 +318,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "<image>人类指令"
+        "value": "<image><image>用户指令"
       },
       {
         "from": "gpt",

@@ -321,6 +326,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
       }
     ],
     "images": [
+      "图像路径(必填)",
       "图像路径(必填)"
     ]
   }
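The hunk above doubles both the `<image>` placeholders and the entries in `images`, keeping the two counts equal as the README requires. A hedged sketch of that invariant, with placeholder paths standing in for real files:

```python
# Sketch only: the number of <image> tokens in the conversation must match the
# number of entries in "images"; the record mirrors the README snippet above.
example = {
    "conversations": [
        {"from": "human", "value": "<image><image>用户指令"},
        {"from": "gpt", "value": "模型回答"},
    ],
    "images": ["path/to/image_1.jpg", "path/to/image_2.jpg"],
}

num_tokens = sum(turn["value"].count("<image>") for turn in example["conversations"])
assert num_tokens == len(example["images"]), "each <image> token needs a matching image path"
```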
@@ -354,7 +360,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "<video>人类指令"
+        "value": "<video><video>用户指令"
       },
       {
         "from": "gpt",

@@ -362,6 +368,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
       }
     ],
     "videos": [
+      "视频路径(必填)",
       "视频路径(必填)"
     ]
   }

@@ -395,7 +402,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
     "conversations": [
       {
         "from": "human",
-        "value": "<audio>人类指令"
+        "value": "<audio><audio>用户指令"
       },
       {
         "from": "gpt",

@@ -403,6 +410,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
       }
     ],
     "audios": [
+      "音频路径(必填)",
       "音频路径(必填)"
     ]
   }

@@ -437,7 +445,7 @@ OpenAI 格式仅仅是 sharegpt 格式的一种特殊情况,其中第一条消
       },
       {
         "role": "user",
-        "content": "人类指令"
+        "content": "用户指令"
       },
       {
         "role": "assistant",
@@ -49,6 +49,8 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     skip_special_tokens: bool = True,
+    default_system: Optional[str] = None,
+    enable_thinking: bool = True,
     seed: Optional[int] = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,

@@ -74,6 +76,8 @@ def vllm_infer(
         cutoff_len=cutoff_len,
         max_samples=max_samples,
         preprocessing_num_workers=16,
+        default_system=default_system,
+        enable_thinking=enable_thinking,
         vllm_config=vllm_config,
         temperature=temperature,
         top_p=top_p,
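The two new options are simply threaded from the script signature into the data-arguments call shown above. A hedged usage sketch follows; the direct import and the argument values are illustrative assumptions (the script is normally driven from the command line), and the surrounding parameters are assumed to keep their existing names:

```python
# Hypothetical direct call of the updated script function with the new options.
# Keep enable_thinking consistent with the setting used during training.
from vllm_infer import vllm_infer  # assumes scripts/ is on PYTHONPATH

vllm_infer(
    model_name_or_path="Qwen/Qwen3-8B",
    dataset="alpaca_en_demo",
    template="qwen3",
    default_system="You are a helpful assistant.",  # new: overrides the template default system
    enable_thinking=False,                          # new: controls reasoning-mode empty CoT
)
```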
@@ -127,14 +131,11 @@ def vllm_infer(
     lora_request = None
 
     # Store all results in these lists
-    all_prompts = []
-    all_preds = []
-    all_labels = []
+    all_prompts, all_preds, all_labels = [], [], []
 
     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
         vllm_inputs, prompts, labels = [], [], []
 
         batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
 
         for j in range(len(batch["input_ids"])):

@@ -176,15 +177,14 @@ def vllm_infer(
         )
 
         results = llm.generate(vllm_inputs, sampling_params, lora_request=lora_request)
 
         preds = [result.outputs[0].text for result in results]
 
         # Accumulate results
         all_prompts.extend(prompts)
         all_preds.extend(preds)
         all_labels.extend(labels)
 
         gc.collect()
 
     # Write all results at once outside the loop
     with open(save_name, "w", encoding="utf-8") as f:
         for text, pred, label in zip(all_prompts, all_preds, all_labels):
@@ -104,10 +104,7 @@ class HuggingfaceEngine(BaseEngine):
             messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else generating_args["enable_thinking"]
-        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
         prompt_ids, _ = template.mm_plugin.process_token_ids(
             prompt_ids,
             None,

@@ -160,10 +160,7 @@ class SGLangEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
-        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
 
         temperature: Optional[float] = input_kwargs.pop("temperature", None)

@@ -124,10 +124,7 @@ class VllmEngine(BaseEngine):
             messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
-        system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", None)
-        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
-        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)
 
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
@@ -52,6 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
+    enable_thinking: bool
     mm_plugin: "BasePlugin"
 
     def encode_oneturn(

@@ -60,7 +61,6 @@ class Template:
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        enable_thinking: bool = False,
     ) -> tuple[list[int], list[int]]:
         r"""Return a single pair of token ids representing prompt and response respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)

@@ -94,7 +94,7 @@ class Template:
 
         return list(stop_token_ids)
 
-    def add_thought(self, content: str) -> str:
+    def add_thought(self, content: str = "") -> str:
         r"""Add empty thought to assistant message."""
         return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
 

@@ -105,7 +105,7 @@ class Template:
 
     def get_thought_word_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
         r"""Get the token ids of thought words."""
-        return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
+        return tokenizer.encode(self.add_thought(), add_special_tokens=False)
 
     def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> list[int]:
         r"""Convert elements to token ids."""
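With the new default argument, `add_thought()` with no content yields just the empty thought wrapper, and `get_thought_word_ids` now reuses it instead of rebuilding the same string. A small sketch of what this produces for Qwen3-style thought words (standalone stand-ins, not the real Template class):

```python
# Sketch of the refactor above; ("<think>", "</think>") matches what the qwen3 tests expect.
thought_words = ("<think>", "</think>")

def add_thought(content: str = "") -> str:
    # same shape as Template.add_thought after this change
    return f"{thought_words[0]}\n\n{thought_words[1]}\n\n" + content

assert add_thought() == "<think>\n\n</think>\n\n"
assert add_thought("The answer is 84.") == "<think>\n\n</think>\n\nThe answer is 84."
# get_thought_word_ids(tokenizer) now simply tokenizes add_thought() with add_special_tokens=False.
```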
@@ -406,26 +406,21 @@ class ReasoningTemplate(Template):
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        enable_thinking: bool = False,
     ) -> tuple[list[int], list[int]]:
         messages = deepcopy(messages)
-        for i in range(len(messages)):
-            if messages[i]["role"] == Role.ASSISTANT and (i != len(messages) - 1):
-                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+        for i in range(1, len(messages) - 2, 2):
+            messages[i]["content"] = self.remove_thought(messages[i]["content"])
 
-        encoded_messages = self._encode(tokenizer, messages, system, tools)
-        prompt_ids = []
-        for encoded_ids in encoded_messages[:-1]:
-            prompt_ids += encoded_ids
-
-        if not enable_thinking and (
-            messages[-1]["role"] == Role.ASSISTANT
-            and self.thought_words[0] not in messages[-1]["content"]
+        prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
+        if (
+            self.thought_words[0] not in messages[-1]["content"]
             and self.thought_words[1] not in messages[-1]["content"]
         ):
-            prompt_ids += self.get_thought_word_ids(tokenizer)
+            if not self.enable_thinking:
+                prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer)
+            else:
+                response_ids = self.get_thought_word_ids(tokenizer) + response_ids
 
-        response_ids = encoded_messages[-1]
         return prompt_ids, response_ids
 
     @override

@@ -436,15 +431,16 @@ class ReasoningTemplate(Template):
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
-        messages = deepcopy(messages)
         encoded_messages = self._encode(tokenizer, messages, system, tools)
-        for i in range(len(messages) - 1):
+        for i in range(0, len(messages), 2):
             if (
-                messages[i + 1]["role"] == Role.ASSISTANT
-                and self.thought_words[0] not in messages[i + 1]["content"]
+                self.thought_words[0] not in messages[i + 1]["content"]
                 and self.thought_words[1] not in messages[i + 1]["content"]
             ):
-                encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                if not self.enable_thinking:
+                    encoded_messages[i] += self.get_thought_word_ids(tokenizer)
+                else:
+                    encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]
 
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
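The rewritten methods route the empty thought block according to the template-level `enable_thinking` flag rather than a per-call argument: with thinking disabled it is appended to the prompt (and therefore excluded from loss computation), with thinking enabled it is prepended to the response (and trained on), matching the README tip. A minimal sketch of just that branch, with plain lists standing in for token ids:

```python
# Minimal sketch of the dispatch introduced above (not the full Template class).
def attach_empty_thought(prompt_ids, response_ids, thought_ids, enable_thinking):
    # When the answer carries no CoT, the empty thought goes to the prompt
    # (thinking off) or to the response (thinking on).
    if not enable_thinking:
        prompt_ids = prompt_ids + thought_ids
    else:
        response_ids = thought_ids + response_ids
    return prompt_ids, response_ids

p, r = attach_empty_thought([1, 2], [3, 4], [9, 9], enable_thinking=True)
assert (p, r) == ([1, 2], [9, 9, 3, 4])
p, r = attach_empty_thought([1, 2], [3, 4], [9, 9], enable_thinking=False)
assert (p, r) == ([1, 2, 9, 9], [3, 4])
```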
@@ -467,6 +463,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
+    enable_thinking: bool = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:

@@ -513,6 +510,7 @@ def register_template(
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
+        enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
 

@@ -549,6 +547,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
     assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
     assistant_slot = assistant_slot[len(prefix) + len(user_slot) :]
+    template_class = ReasoningTemplate if "<think>" in assistant_slot else Template
     assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags
 
     if len(user_slot) > len(user_slot_empty_system):

@@ -558,7 +557,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
     else:  # if defaut_system is empty, user_slot_empty_system will be longer than user_slot
         default_system = ""
 
-    return Template(
+    return template_class(
         format_user=StringFormatter(slots=[user_slot]),
         format_assistant=StringFormatter(slots=[assistant_slot]),
         format_system=StringFormatter(slots=[system_slot]),

@@ -572,6 +571,7 @@ def parse_template(tokenizer: "PreTrainedTokenizer") -> "Template":
         efficient_eos=False,
         replace_eos=False,
         replace_jinja_template=False,
+        enable_thinking=True,
         mm_plugin=get_mm_plugin(name="base"),
     )
 

@@ -600,6 +600,11 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
 
+    if data_args.default_system is not None:
+        logger.info_rank0(f"Using default system message: {data_args.default_system}.")
+        template.default_system = data_args.default_system
+
+    template.enable_thinking = data_args.enable_thinking
     template.fix_special_tokens(tokenizer)
     template.fix_jinja_template(tokenizer)
     return template
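With these changes `default_system` and `enable_thinking` become data-side settings applied once when the template is built. A hedged sketch of wiring them through `DataArguments`, mirroring what the updated tests do (import paths as used by the test suite; the system message is an invented example):

```python
# Sketch: configuring the new data-side options when building a template.
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
data_args = DataArguments(
    template="qwen3",
    default_system="You are a helpful assistant.",  # new field: overrides the template default
    enable_thinking=False,                          # new field: where the empty CoT is placed
)
template = get_template_and_fix_tokenizer(tokenizer, data_args)
prompt_ids, answer_ids = template.encode_oneturn(tokenizer, [
    {"role": "user", "content": "How are you"},
    {"role": "assistant", "content": "I am fine!"},
])
```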
@@ -115,6 +115,14 @@ class DataArguments:
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Override the default system message in the template."},
+    )
+    enable_thinking: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
     tokenized_path: Optional[str] = field(
         default=None,
         metadata={
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import asdict, dataclass, field
-from typing import Any, Optional
+from typing import Any
 
 from transformers import GenerationConfig
 

@@ -62,18 +62,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    default_system: Optional[str] = field(
-        default=None,
-        metadata={"help": "Default system message to use in chat completion."},
-    )
     skip_special_tokens: bool = field(
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
     )
-    enable_thinking: bool = field(
-        default=True,
-        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
-    )
 
     def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
         args = asdict(self)
@@ -15,6 +15,7 @@
 import json
 import os
 from collections.abc import Generator
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional
 
 from transformers.utils import is_torch_npu_available

@@ -68,6 +69,14 @@ def _format_response(text: str, lang: str, escape_html: bool, thought_words: tup
     )
 
 
+@contextmanager
+def update_attr(obj: Any, name: str, value: Any):
+    old_value = getattr(obj, name, None)
+    setattr(obj, name, value)
+    yield
+    setattr(obj, name, old_value)
+
+
 class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
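The new `update_attr` context manager temporarily overrides a single attribute and restores the previous value on exit; the Web UI uses it below to flip the template's `enable_thinking` for one streamed chat. A self-contained sketch of the behaviour:

```python
# Self-contained sketch of the update_attr helper added above.
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any

@contextmanager
def update_attr(obj: Any, name: str, value: Any):
    old_value = getattr(obj, name, None)
    setattr(obj, name, value)
    yield
    setattr(obj, name, old_value)

template = SimpleNamespace(enable_thinking=True)
with update_attr(template, "enable_thinking", False):
    assert template.enable_thinking is False  # overridden inside the block
assert template.enable_thinking is True      # restored afterwards
```

Note that the helper restores the attribute only on normal exit; wrapping the `yield` in try/finally would also restore it if the streamed generation raised an exception.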
@@ -198,35 +207,35 @@ class WebChatModel(ChatModel):
         Inputs: infer.chatbot, infer.messages, infer.system, infer.tools, infer.image, infer.video, ...
         Output: infer.chatbot, infer.messages
         """
-        chatbot.append({"role": "assistant", "content": ""})
-        response = ""
-        for new_text in self.stream_chat(
-            messages,
-            system,
-            tools,
-            images=[image] if image else None,
-            videos=[video] if video else None,
-            audios=[audio] if audio else None,
-            max_new_tokens=max_new_tokens,
-            top_p=top_p,
-            temperature=temperature,
-            skip_special_tokens=skip_special_tokens,
-            enable_thinking=enable_thinking,
+        with update_attr(self.engine.template, "enable_thinking", enable_thinking):
+            chatbot.append({"role": "assistant", "content": ""})
+            response = ""
+            for new_text in self.stream_chat(
+                messages,
+                system,
+                tools,
+                images=[image] if image else None,
+                videos=[video] if video else None,
+                audios=[audio] if audio else None,
+                max_new_tokens=max_new_tokens,
+                top_p=top_p,
+                temperature=temperature,
+                skip_special_tokens=skip_special_tokens,
             ):
                 response += new_text
                 if tools:
                     result = self.engine.template.extract_tool(response)
                 else:
                     result = response
 
                 if isinstance(result, list):
                     tool_calls = [{"name": tool.name, "arguments": json.loads(tool.arguments)} for tool in result]
                     tool_calls = json.dumps(tool_calls, ensure_ascii=False)
                     output_messages = messages + [{"role": Role.FUNCTION.value, "content": tool_calls}]
                     bot_text = "```json\n" + tool_calls + "\n```"
                 else:
                     output_messages = messages + [{"role": Role.ASSISTANT.value, "content": result}]
                     bot_text = _format_response(result, lang, escape_html, self.engine.template.thought_words)
 
                 chatbot[-1] = {"role": "assistant", "content": bot_text}
                 yield chatbot, output_messages
@@ -126,29 +126,50 @@ def test_encode_multiturn(use_fast: bool):
 
 
 @pytest.mark.parametrize("use_fast", [True, False])
-def test_reasoning_encode_oneturn(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, MESSAGES)
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
     prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
-        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        f"{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>\n"
+    answer_str = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        if enable_thinking:
+            answer_str = "<think>\n\n</think>\n\n" + answer_str
+        else:
+            prompt_str = prompt_str + "<think>\n\n</think>\n\n"
+
     _check_tokenization(tokenizer, (prompt_ids, answer_ids), (prompt_str, answer_str))
 
 
 @pytest.mark.parametrize("use_fast", [True, False])
-def test_reasoning_encode_multiturn(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
-    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))
-    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES)
-    prompt_str_1 = "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-    answer_str_1 = "I am fine!<|im_end|>\n"
-    prompt_str_2 = "<|im_start|>user\n你好<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
-    answer_str_2 = "很高兴认识你!<|im_end|>\n"
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    encoded_pairs = template.encode_multiturn(tokenizer, messages)
+    prompt_str_1 = f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
+    prompt_str_2 = f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        if enable_thinking:
+            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
+            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
+        else:
+            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
+            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
+
     _check_tokenization(
         tokenizer,
         (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
@@ -193,12 +214,12 @@ def test_get_stop_token_ids():
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_gemma_template(use_fast: bool):
     prompt_str = (
-        "<bos><start_of_turn>user\nHow are you<end_of_turn>\n"
-        "<start_of_turn>model\nI am fine!<end_of_turn>\n"
-        "<start_of_turn>user\n你好<end_of_turn>\n"
+        f"<bos><start_of_turn>user\n{MESSAGES[0]['content']}<end_of_turn>\n"
+        f"<start_of_turn>model\n{MESSAGES[1]['content']}<end_of_turn>\n"
+        f"<start_of_turn>user\n{MESSAGES[2]['content']}<end_of_turn>\n"
         "<start_of_turn>model\n"
     )
-    answer_str = "很高兴认识你!<end_of_turn>\n"
+    answer_str = f"{MESSAGES[3]['content']}<end_of_turn>\n"
     _check_template("google/gemma-3-4b-it", "gemma", prompt_str, answer_str, use_fast)
 
 

@@ -206,12 +227,12 @@ def test_gemma_template(use_fast: bool):
 @pytest.mark.parametrize("use_fast", [True, False])
 def test_llama3_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|>"
-        "<|start_header_id|>assistant<|end_header_id|>\n\nI am fine!<|eot_id|>"
-        "<|start_header_id|>user<|end_header_id|>\n\n你好<|eot_id|>"
+        f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[0]['content']}<|eot_id|>"
+        f"<|start_header_id|>assistant<|end_header_id|>\n\n{MESSAGES[1]['content']}<|eot_id|>"
+        f"<|start_header_id|>user<|end_header_id|>\n\n{MESSAGES[2]['content']}<|eot_id|>"
         "<|start_header_id|>assistant<|end_header_id|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot_id|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot_id|>"
     _check_template("meta-llama/Meta-Llama-3-8B-Instruct", "llama3", prompt_str, answer_str, use_fast)
 
 

@@ -220,12 +241,12 @@ def test_llama3_template(use_fast: bool):
 )
 def test_llama4_template(use_fast: bool):
     prompt_str = (
-        "<|begin_of_text|><|header_start|>user<|header_end|>\n\nHow are you<|eot|>"
-        "<|header_start|>assistant<|header_end|>\n\nI am fine!<|eot|>"
-        "<|header_start|>user<|header_end|>\n\n你好<|eot|>"
+        f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{MESSAGES[0]['content']}<|eot|>"
+        f"<|header_start|>assistant<|header_end|>\n\n{MESSAGES[1]['content']}<|eot|>"
+        f"<|header_start|>user<|header_end|>\n\n{MESSAGES[2]['content']}<|eot|>"
         "<|header_start|>assistant<|header_end|>\n\n"
     )
-    answer_str = "很高兴认识你!<|eot|>"
+    answer_str = f"{MESSAGES[3]['content']}<|eot|>"
     _check_template(TINY_LLAMA4, "llama4", prompt_str, answer_str, use_fast)
 
 

@@ -234,12 +255,12 @@ def test_llama4_template(use_fast: bool):
 )
 def test_phi4_template(use_fast: bool):
     prompt_str = (
-        "<|im_start|>user<|im_sep|>How are you<|im_end|>"
-        "<|im_start|>assistant<|im_sep|>I am fine!<|im_end|>"
-        "<|im_start|>user<|im_sep|>你好<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[0]['content']}<|im_end|>"
+        f"<|im_start|>assistant<|im_sep|>{MESSAGES[1]['content']}<|im_end|>"
+        f"<|im_start|>user<|im_sep|>{MESSAGES[2]['content']}<|im_end|>"
         "<|im_start|>assistant<|im_sep|>"
     )
-    answer_str = "很高兴认识你!<|im_end|>"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>"
     _check_template("microsoft/phi-4", "phi4", prompt_str, answer_str, use_fast)
 
 
@@ -247,34 +268,30 @@ def test_phi4_template(use_fast: bool):
 def test_qwen2_5_template(use_fast: bool):
     prompt_str = (
         "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{MESSAGES[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "很高兴认识你!<|im_end|>\n"
+    answer_str = f"{MESSAGES[3]['content']}<|im_end|>\n"
     _check_template("Qwen/Qwen2.5-7B-Instruct", "qwen", prompt_str, answer_str, use_fast)
 
 
 @pytest.mark.parametrize("use_fast", [True, False])
-def test_qwen3_template(use_fast: bool):
+@pytest.mark.parametrize("cot_messages", [True, False])
+def test_qwen3_template(use_fast: bool, cot_messages: bool):
+    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
-        "<|im_start|>assistant\n<think>\n\n</think>\n\n"
-    )
-    answer_str = "很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast)
-
-    prompt_str = (
-        "<|im_start|>user\nHow are you<|im_end|>\n"
-        "<|im_start|>assistant\nI am fine!<|im_end|>\n"
-        "<|im_start|>user\n你好<|im_end|>\n"
+        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n"
+        f"<|im_start|>assistant\n{MESSAGES[1]['content']}<|im_end|>\n"
+        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n"
         "<|im_start|>assistant\n"
     )
-    answer_str = "<think>\n模型思考内容\n</think>\n\n很高兴认识你!<|im_end|>\n"
-    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=MESSAGES_WITH_THOUGHT)
+    answer_str = f"{messages[3]['content']}<|im_end|>\n"
+    if not cot_messages:
+        answer_str = "<think>\n\n</think>\n\n" + answer_str
+
+    _check_template("Qwen/Qwen3-8B", "qwen3", prompt_str, answer_str, use_fast, messages=messages)
 
 

@@ -293,6 +310,7 @@ def test_parse_llama3_template():
 def test_parse_qwen_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "Template"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]

@@ -303,6 +321,7 @@ def test_parse_qwen_template():
 def test_parse_qwen3_template():
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", token=HF_TOKEN)
     template = parse_template(tokenizer)
+    assert template.__class__.__name__ == "ReasoningTemplate"
     assert template.format_user.slots == ["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]
     assert template.format_assistant.slots == ["{{content}}<|im_end|>\n"]
     assert template.format_system.slots == ["<|im_start|>system\n{{content}}<|im_end|>\n"]