diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index 938e9bcc..cc815d1b 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -53,7 +53,7 @@ class ChatModel: pad_token_id=self.tokenizer.pad_token_id )) - if int(num_return_sequences) > 1: + if isinstance(num_return_sequences, int) and num_return_sequences > 1: generating_args["do_sample"] = True if max_length: