mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix qwen inference
Former-commit-id: ea30da4794c094ce5442b7b2936de0686a1739eb
This commit is contained in:
parent
9c84c4ed5d
commit
65ef0ff491
@ -1,7 +1,8 @@
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
from transformers import PreTrainedModel, TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
|
||||
from llmtuner.extras.template import get_template
|
||||
@ -20,6 +21,7 @@ class ChatModel:
|
||||
self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
|
||||
]
|
||||
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user