fix resize vocab at inference #3022

This commit is contained in:
hiyouga
2024-04-03 18:14:24 +08:00
parent ce77d98872
commit 148bda353f
9 changed files with 31 additions and 40 deletions

View File

@@ -10,7 +10,7 @@ from transformers.utils.versions import require_version
from ..extras.logging import get_logger
from ..extras.packages import is_galore_available
from ..hparams import FinetuningArguments, ModelArguments
from ..model import find_all_linear_modules, load_model_and_tokenizer, load_valuehead_params
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
if is_galore_available():
@@ -87,16 +87,18 @@ def create_ref_model(
)
ref_model_args = ModelArguments(**ref_model_args_dict)
ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
ref_model, _ = load_model_and_tokenizer(
ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
tokenizer = load_tokenizer(ref_model_args)
ref_model = load_model(
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from {}".format(finetuning_args.ref_model))
else:
if finetuning_args.finetuning_type == "lora":
ref_model = None
else:
ref_model, _ = load_model_and_tokenizer(
model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
tokenizer = load_tokenizer(model_args)
ref_model = load_model(
tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
)
logger.info("Created reference model from the model itself.")
@@ -141,8 +143,9 @@ def create_reward_model(
)
reward_model_args = ModelArguments(**reward_model_args_dict)
reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
reward_model, _ = load_model_and_tokenizer(
reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
tokenizer = load_tokenizer(reward_model_args)
reward_model = load_model(
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
)
logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")