fix qwen inference

Former-commit-id: ea30da4794c094ce5442b7b2936de0686a1739eb
Author: hiyouga
Date:   2023-08-03 16:15:38 +08:00
Parent: 9c84c4ed5d
Commit: 65ef0ff491


@@ -1,7 +1,8 @@
 import torch
+from types import MethodType
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import PreTrainedModel, TextIteratorStreamer
 
 from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
 from llmtuner.extras.template import get_template
@@ -20,6 +21,7 @@ class ChatModel:
             self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
         ]
         self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
+        self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model
 
     def process_args(
         self,
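
The patch works because MethodType(PreTrainedModel.generate, self.model) binds the stock generate implementation directly to the model instance; instance attributes shadow class attributes, so subsequent calls bypass the incompatible generate override that Qwen's remote modeling code defines on its subclass. Below is a minimal, self-contained sketch of the same rebinding trick; the Base and Custom classes are hypothetical stand-ins for illustration, not LLaMA-Factory or transformers code.

from types import MethodType

class Base:
    def generate(self):
        # stock implementation, analogous to PreTrainedModel.generate
        return "base generate"

class Custom(Base):
    def generate(self):
        # incompatible override, analogous to Qwen's remote-code generate
        return "custom generate"

model = Custom()
print(model.generate())  # "custom generate" -- the subclass override wins

# Rebind the base implementation onto this instance; attribute lookup now
# finds the instance attribute before the class's overriding method.
model.generate = MethodType(Base.generate, model)
print(model.generate())  # "base generate"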