fix qwen inference

Former-commit-id: 2c5fe45ce1405124f12ecd20e263b5538af97972
This commit is contained in:
hiyouga 2023-08-03 16:15:38 +08:00
parent 2e19afedb8
commit e434348216

View File

@@ -1,7 +1,8 @@
 import torch
+from types import MethodType
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
-from transformers import TextIteratorStreamer
+from transformers import PreTrainedModel, TextIteratorStreamer
 from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
 from llmtuner.extras.template import get_template
@@ -20,6 +21,7 @@ class ChatModel:
             self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
         ]
         self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
+        self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model

     def process_args(
         self,