diff --git a/src/llmtuner/chat/vllm_engine.py b/src/llmtuner/chat/vllm_engine.py index faf8c9fe..aaaad2f1 100644 --- a/src/llmtuner/chat/vllm_engine.py +++ b/src/llmtuner/chat/vllm_engine.py @@ -100,8 +100,9 @@ class VllmEngine(BaseEngine): max_new_tokens = input_kwargs.pop("max_new_tokens", None) stop = input_kwargs.pop("stop", None) + max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"] if max_length: - max_tokens = max_length - prompt_length + max_tokens = max_length - prompt_length if max_length > prompt_length else 1 if max_new_tokens: max_tokens = max_new_tokens