From 6a584b40928fb6d69e22c1403db226eb04358a30 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Wed, 30 Apr 2025 17:21:30 +0800
Subject: [PATCH] [hparam] add enable think argument (#7928)

---
 src/llamafactory/chat/hf_engine.py          | 3 ++-
 src/llamafactory/chat/sglang_engine.py      | 3 ++-
 src/llamafactory/chat/vllm_engine.py        | 3 ++-
 src/llamafactory/hparams/generating_args.py | 4 ++++
 4 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 5f335d8b..8fb08dee 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -106,7 +106,8 @@ class HuggingfaceEngine(BaseEngine):
         # add thought words to avoid skipping thinking
         paired_messages = messages + [{"role": "assistant", "content": template.add_thought("")}]
         system = system or generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", True)
+        enable_thinking = input_kwargs.pop("enable_thinking", None)
+        enable_thinking = enable_thinking if enable_thinking is not None else generating_args["enable_thinking"]
         prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools, enable_thinking)
         prompt_ids, _ = template.mm_plugin.process_token_ids(
             prompt_ids,
diff --git a/src/llamafactory/chat/sglang_engine.py b/src/llamafactory/chat/sglang_engine.py
index 44414a05..7af561ca 100644
--- a/src/llamafactory/chat/sglang_engine.py
+++ b/src/llamafactory/chat/sglang_engine.py
@@ -149,7 +149,8 @@ class SGLangEngine(BaseEngine):
         # add thought words to avoid skipping thinking
         paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}]
         system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", True)
+        enable_thinking = input_kwargs.pop("enable_thinking", None)
+        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
         prompt_length = len(prompt_ids)
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 274d12e3..dde67dd6 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -126,7 +126,8 @@ class VllmEngine(BaseEngine):
         # add thought words to avoid skipping thinking
         paired_messages = messages + [{"role": "assistant", "content": self.template.add_thought("")}]
         system = system or self.generating_args["default_system"]
-        enable_thinking = input_kwargs.pop("enable_thinking", True)
+        enable_thinking = input_kwargs.pop("enable_thinking", None)
+        enable_thinking = enable_thinking if enable_thinking is not None else self.generating_args["enable_thinking"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools, enable_thinking)
         prompt_length = len(prompt_ids)
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index 251822b1..ac377543 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -70,6 +70,10 @@ class GeneratingArguments:
         default=True,
         metadata={"help": "Whether or not to remove special tokens in the decoding."},
     )
+    enable_thinking: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
+    )
 
     def to_dict(self, obey_generation_config: bool = False) -> dict[str, Any]:
         args = asdict(self)
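
Note: the resolution order this patch establishes in all three engines is per-request
value first, configured default second. Below is a minimal, standalone sketch of that
precedence logic under the patch's semantics; the resolve_enable_thinking helper name
is hypothetical and not part of the repository.

    def resolve_enable_thinking(input_kwargs: dict, generating_args: dict) -> bool:
        # Per-request value (may be absent); popping mirrors the engine code so
        # the key is not also forwarded to the sampler/generation kwargs.
        enable_thinking = input_kwargs.pop("enable_thinking", None)
        # Fall back to the new GeneratingArguments field when no per-request value
        # is given; the `is not None` check keeps an explicit False from being
        # silently overridden by the default.
        return enable_thinking if enable_thinking is not None else generating_args["enable_thinking"]

    # A request that does not set the key uses the configured default.
    assert resolve_enable_thinking({}, {"enable_thinking": True}) is True
    # An explicit per-request False beats the configured default.
    assert resolve_enable_thinking({"enable_thinking": False}, {"enable_thinking": True}) is False

This is why the old `input_kwargs.pop("enable_thinking", True)` had to change: with a
hardcoded True fallback, the new hyperparameter could never take effect for requests
that omit the key.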