fix qwen inference

Former-commit-id: 2c5fe45ce1405124f12ecd20e263b5538af97972
This commit is contained in:
hiyouga 2023-08-03 16:15:38 +08:00
parent 2e19afedb8
commit e434348216

View File

@@ -1,7 +1,8 @@
import torch
from types import MethodType
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import TextIteratorStreamer
from transformers import PreTrainedModel, TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
from llmtuner.extras.template import get_template
@@ -20,6 +21,7 @@ class ChatModel:
self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
]
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model
def process_args(
self,