mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
fix qwen inference
Former-commit-id: 2780792754b484bf4d42af5ebbc51c7ed2181ce9
parent 65ef0ff491
commit 788d1250c1
@@ -17,9 +17,7 @@ class ChatModel:
         self.model = dispatch_model(self.model)
         self.template = get_template(data_args.template)
         self.source_prefix = data_args.source_prefix
-        self.stop_ids = [
-            self.tokenizer.encode(word, add_special_tokens=False)[0] for word in self.template.stop_words
-        ]
+        self.stop_ids = self.tokenizer.convert_tokens_to_ids(self.template.stop_words)
         self.tokenizer.add_special_tokens(dict(additional_special_tokens=self.template.stop_words))
         self.model.generate = MethodType(PreTrainedModel.generate, self.model) # a monkey fix for qwen model

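The substance of the fix is how the stop ids are built. Stop words such as Qwen's "<|im_end|>" are registered as special tokens, and tokenizer.encode(word, add_special_tokens=False)[0] treats each word as ordinary text, so it can split the word into several pieces and keep only the first id, corrupting the stop criterion. convert_tokens_to_ids looks the whole token string up in the vocabulary instead. A minimal sketch of the difference, assuming any Hugging Face tokenizer; the helper names below are illustrative and not part of the commit:

from typing import List

from transformers import PreTrainedTokenizer


def stop_ids_via_encode(tokenizer: PreTrainedTokenizer, stop_words: List[str]) -> List[int]:
    # pre-fix behaviour: encode each stop word as plain text and keep only the
    # first id; a word that tokenizes into several pieces yields a wrong stop id
    return [tokenizer.encode(word, add_special_tokens=False)[0] for word in stop_words]


def stop_ids_via_vocab(tokenizer: PreTrainedTokenizer, stop_words: List[str]) -> List[int]:
    # post-fix behaviour: map each whole token string to its vocabulary id,
    # which resolves a registered special token to a single correct id
    return tokenizer.convert_tokens_to_ids(stop_words)

On a tokenizer whose chat template ends turns with a registered special stop token, the first helper may error or return the id of a partial piece, while the second returns the single id of the stop token itself.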
@@ -6,13 +6,14 @@ from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    BitsAndBytesConfig
+    BitsAndBytesConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.deepspeed import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

 from llmtuner.extras.logging import reset_logging, get_logger
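This hunk only folds PretrainedConfig, PreTrainedModel, and PreTrainedTokenizerBase into the existing from transformers import (...) statement instead of pulling them from submodules. The assumption is that the package root re-exports the same class objects as the submodule paths used before; a short check of that assumption:

from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.modeling_utils import PretrainedConfig as _Config, PreTrainedModel as _Model
from transformers.tokenization_utils import PreTrainedTokenizerBase as _TokenizerBase

# both import spellings should bind the identical class objects
assert PretrainedConfig is _Config
assert PreTrainedModel is _Model
assert PreTrainedTokenizerBase is _TokenizerBase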
@@ -22,6 +23,7 @@ from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter

 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
     from llmtuner.hparams import ModelArguments


@@ -40,7 +42,7 @@ def load_model_and_tokenizer(
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.

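The new return annotation works together with the if TYPE_CHECKING: import added in the previous hunk: PreTrainedTokenizer is needed only by static type checkers, so it is imported in a block that never runs at runtime and referenced in the signature as the quoted forward reference "PreTrainedTokenizer". A minimal sketch of the pattern, with illustrative names rather than the commit's own function:

from typing import TYPE_CHECKING, Tuple

from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel

if TYPE_CHECKING:
    # evaluated by type checkers only; never executed at runtime
    from transformers import PreTrainedTokenizer


def load(name_or_path: str) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
    # the quoted annotation is a forward reference, so this module imports and
    # runs even though PreTrainedTokenizer is not a runtime name here
    model = AutoModelForCausalLM.from_pretrained(name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(name_or_path)
    return model, tokenizer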