mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
fix qwen inference
Former-commit-id: ea30da4794c094ce5442b7b2936de0686a1739eb
This commit is contained in:
parent
9c84c4ed5d
commit
65ef0ff491
@ -1,7 +1,8 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from types import MethodType
|
||||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from transformers import TextIteratorStreamer
|
from transformers import PreTrainedModel, TextIteratorStreamer
|
||||||
|
|
||||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
|
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopwords_criteria
|
||||||
from llmtuner.extras.template import get_template
|
from llmtuner.extras.template import get_template
|
||||||
@ -20,6 +21,7 @@ class ChatModel:
|
|||||||
self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
|
self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
|
||||||
]
|
]
|
||||||
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
|
||||||
|
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model
|
||||||
|
|
||||||
def process_args(
|
def process_args(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user