fix qwen inference

Former-commit-id: 2780792754b484bf4d42af5ebbc51c7ed2181ce9
This commit is contained in:
hiyouga 2023-08-03 16:31:55 +08:00
parent 65ef0ff491
commit 788d1250c1
2 changed files with 7 additions and 7 deletions

View File

@@ -17,9 +17,7 @@ class ChatModel:
         self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
-        self.stop_ids = [
-            self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
-        ]
+        self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
         self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
         self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model

View File

@@ -6,13 +6,14 @@ from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    BitsAndBytesConfig
+    BitsAndBytesConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.deepspeed import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

 from llmtuner.extras.logging import reset_logging, get_logger
@@ -22,6 +23,7 @@ from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter

 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
     from llmtuner.hparams import ModelArguments
@@ -40,7 +42,7 @@ def load_model_and_tokenizer(
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.