fix resize vocab at inference #3022

This commit is contained in:
hiyouga
2024-04-03 18:14:24 +08:00
parent ce77d98872
commit 148bda353f
9 changed files with 31 additions and 40 deletions

View File

@@ -9,7 +9,7 @@ from transformers import GenerationConfig, TextIteratorStreamer
from ..data import get_template_and_fix_tokenizer
from ..extras.misc import get_logits_processor
from ..model import load_model_and_tokenizer
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@@ -30,11 +30,12 @@ class HuggingfaceEngine(BaseEngine):
generating_args: "GeneratingArguments",
) -> None:
self.can_generate = finetuning_args.stage == "sft"
self.model, self.tokenizer = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.tokenizer = load_tokenizer(model_args)
self.tokenizer.padding_side = "left" if self.can_generate else "right"
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template)
self.model = load_model(
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.generating_args = generating_args.to_dict()
@staticmethod