From c8890c32db8c8a6c83fcee684de05c263f466cc8 Mon Sep 17 00:00:00 2001
From: Kingsley
Date: Mon, 27 Apr 2026 00:32:44 +0800
Subject: [PATCH] [data] support discard history cot for multiturn (#10435)

---
 src/llamafactory/data/processor/supervised.py |  3 +-
 src/llamafactory/data/template.py             | 30 +++++++----------
 tests/data/test_template.py                   | 33 +++++++++++++++++++
 3 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py
index 25eb82adb..cc7bf3a96 100644
--- a/src/llamafactory/data/processor/supervised.py
+++ b/src/llamafactory/data/processor/supervised.py
@@ -61,7 +61,8 @@ class SupervisedDatasetProcessor(DatasetProcessor):
         input_ids, labels = self.template.mm_plugin.process_token_ids(
             [], [], images, videos, audios, self.tokenizer, self.processor
         )
-        encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools)
+        discarding_history_cot = self.data_args.mask_history and not self.template.preserve_thinking
+        encoded_pairs = self.template.encode_multiturn(self.tokenizer, messages, system, tools, discarding_history_cot)
         total_length = len(input_ids) + (1 if self.template.efficient_eos else 0)
         if self.data_args.mask_history:
             encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index ff592db05..c8b4b0007 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -79,6 +79,7 @@ class Template:
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        discarding_history_cot: bool = False,  # only affects reasoning templates
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -441,14 +442,24 @@ class ReasoningTemplate(Template):
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
+        discarding_history_cot: bool = False,
     ) -> list[tuple[list[int], list[int]]]:
         messages = deepcopy(messages)
         if self.enable_thinking is False:  # remove all cot
             for i in range(1, len(messages), 2):
                 messages[i]["content"] = self.remove_thought(messages[i]["content"])
 
+        if discarding_history_cot:
+            for i in range(1, len(messages) - 2, 2):  # preserve the last cot
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
         encoded_messages = self._encode(tokenizer, messages, system, tools)
-        for i in range(0, len(messages), 2):
+        if discarding_history_cot:
+            turn_indices = [len(messages) - 2]
+        else:
+            turn_indices = range(0, len(messages), 2)
+
+        for i in turn_indices:
             if (
                 self.thought_words[0].strip() not in messages[i + 1]["content"]
                 and self.thought_words[1].strip() not in messages[i + 1]["content"]
@@ -2135,23 +2146,6 @@ register_template(
 )
 
 
-# copied from qwen3_5_nothink template
-register_template(
-    name="qwen3_6_nothink",
-    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
-    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
-    format_observation=StringFormatter(
-        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
-    ),
-    format_tools=ToolFormatter(tool_format="qwen3_5"),
-    stop_words=["<|im_end|>"],
-    replace_eos=True,
-    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
-)
-
-
 register_template(
     name="sailor",
     format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index b9d9ab2d8..e44210804 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -181,6 +181,39 @@ def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
         (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
     )
 
+@pytest.mark.runs_on(["cpu", "mps"])
+@pytest.mark.parametrize("enable_thinking", [True, False, None])
+@pytest.mark.parametrize("discarding_history_cot", [True, False])
+def test_reasoning_encode_multiturn_discarding_history_cot(enable_thinking: bool, discarding_history_cot: bool):
+    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
+    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT, discarding_history_cot=discarding_history_cot)
+
+    prompt_str_1 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+    prompt_str_2 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+
+    if enable_thinking is False:
+        answer_str_1 = f"{MESSAGES[1]['content']}<|im_end|>\n"
+        answer_str_2 = f"{MESSAGES[3]['content']}<|im_end|>\n"
+        if discarding_history_cot:
+            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
+        else:
+            prompt_str_1 = prompt_str_1 + "<think>\n\n</think>\n\n"
+            prompt_str_2 = prompt_str_2 + "<think>\n\n</think>\n\n"
+    else:
+        if discarding_history_cot:
+            answer_str_1 = f"{MESSAGES[1]['content']}<|im_end|>\n"
+        else:
+            answer_str_1 = f"{MESSAGES_WITH_THOUGHT[1]['content']}<|im_end|>\n"
+        answer_str_2 = f"{MESSAGES_WITH_THOUGHT[3]['content']}<|im_end|>\n"
+
+    _check_tokenization(
+        tokenizer,
+        (encoded_pairs[0][0], encoded_pairs[0][1], encoded_pairs[1][0], encoded_pairs[1][1]),
+        (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
+    )
+
 
 @pytest.mark.runs_on(["cpu", "mps"])
 def test_jinja_template():
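
Usage note for review (not part of the diff): with this change, a reasoning
template can encode a multi-turn sample so that only the final assistant turn
keeps its chain of thought. The sketch below is illustrative only; it reuses
the helpers imported by tests/data/test_template.py
(get_template_and_fix_tokenizer, DataArguments), and the conversation contents
are made up.

    from transformers import AutoTokenizer

    from llamafactory.data import get_template_and_fix_tokenizer
    from llamafactory.hparams import DataArguments

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
    template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="qwen3"))

    # Hypothetical conversation; each assistant reply carries a <think> block.
    messages = [
        {"role": "user", "content": "How many r's are in 'strawberry'?"},
        {"role": "assistant", "content": "<think>\ns-t-r-a-w-b-e-r-r-y\n</think>\n\nThree."},
        {"role": "user", "content": "And in 'raspberry'?"},
        {"role": "assistant", "content": "<think>\nr-a-s-p-b-e-r-r-y\n</think>\n\nAlso three."},
    ]

    # discarding_history_cot=True strips the <think> block from every assistant
    # turn except the last, so earlier responses are encoded without their CoT.
    pairs = template.encode_multiturn(tokenizer, messages, discarding_history_cot=True)
    for prompt_ids, response_ids in pairs:
        print(repr(tokenizer.decode(response_ids)))

On the training side no new argument is exposed: per the supervised.py hunk,
history CoT is discarded exactly when mask_history is enabled and the template
does not preserve thinking, so the one unmasked (last) turn still learns to
emit its reasoning.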