From 65ef0ff491d4f0ecc84abb68368e56a737d1383e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 3 Aug 2023 16:15:38 +0800
Subject: [PATCH] fix qwen inference

Former-commit-id: ea30da4794c094ce5442b7b2936de0686a1739eb
---
 src/llmtuner/chat/stream_chat.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index b1ded67a..f39ed960 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -1,7 +1,8 @@
 import torch
+from types import MethodType
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import PreTrainedModel, TextIteratorStreamer
 
 from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
 from llmtuner.extras.template import get_template
@@ -20,6 +21,7 @@ class ChatModel:
             self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
         ]
         self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
+        self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model
 
     def process_args(
         self,