diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py index b1ded67a..f39ed960 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/stream_chat.py @@ -1,7 +1,8 @@ import torch +from types import MethodType from typing import Any, Dict, Generator, List, Optional, Tuple from threading import Thread -from transformers import TextIteratorStreamer +from transformers import PreTrainedModel, TextIteratorStreamer from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria from llmtuner.extras.template import get_template @@ -20,6 +21,7 @@ class ChatModel: self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words ] self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words)) + self.model.generate = MethodType(PreTrainedModel.generate, self.model) # monkey patch for Qwen models: bind the stock PreTrainedModel.generate onto this model instance (presumably to bypass a model-specific generate override -- TODO confirm) def process_args( self,