fix qwen inference

hiyouga
2023-08-03 16:15:38 +08:00
parent 87f8f830e2
commit ea30da4794


@@ -1,7 +1,8 @@
 import torch
+from types import MethodType
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import PreTrainedModel, TextIteratorStreamer
 from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
 from llmtuner.extras.template import get_template
@@ -20,6 +21,7 @@ class ChatModel:
             self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
         ]
         self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
+        self.model.generate = MethodType(PreTrainedModel.generate, self.model) # monkey patch: restore the stock PreTrainedModel.generate for Qwen models
 
     def process_args(
         self,
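
For context: Qwen's remote modeling code ships its own generate(), so the patched line rebinds the stock PreTrainedModel.generate onto the model instance with types.MethodType, routing subsequent calls through the standard Transformers generation path. Below is a minimal sketch of that instance-level rebinding technique; the Base/CustomModel names are illustrative stand-ins, not from the repo.

    from types import MethodType

    class Base:
        def generate(self, prompt: str) -> str:
            return "base: " + prompt

    class CustomModel(Base):
        # stands in for a remote-code model that overrides generate()
        def generate(self, prompt: str) -> str:
            return "custom: " + prompt

    model = CustomModel()
    print(model.generate("hello"))  # custom: hello

    # Bind the base class's plain function to this instance; the instance
    # attribute now shadows the subclass override.
    model.generate = MethodType(Base.generate, model)
    print(model.generate("hello"))  # base: hello

Assigning to model.generate on the instance shadows the subclass override without modifying either class, so other instances remain unaffected.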