From 9ead5a2d21dcb270d7e6439a6b4be5cbc44df4d1 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 11 Dec 2023 17:50:02 +0800
Subject: [PATCH] support resize embeddings #1786

Former-commit-id: 64744dde89ccb9a24a46985a99151ad2dde03919
---
 src/llmtuner/model/loader.py |  5 ++++-
 src/llmtuner/model/utils.py  | 12 ++++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 2434016e..3df33c70 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -28,7 +28,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.model.adapter import init_adapter
-from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -185,6 +185,9 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
 
+    # Resize token embeddings
+    resize_embedding_layer(model, tokenizer)
+
     # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 42bef35b..b52582f5 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -11,6 +11,7 @@ from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
+    from transformers.tokenization_utils import PreTrainedTokenizer
 
     from llmtuner.hparams import DataArguments
 
@@ -181,3 +182,14 @@ def prepare_model_for_training(
             output_layer.register_forward_hook(fp32_forward_post_hook)
 
     return model
+
+
+def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
+    r"""
+    Resize token embeddings.
+    """
+    old_vocab_size = model.get_input_embeddings().weight.size(0)
+    new_vocab_size = len(tokenizer)
+    if new_vocab_size != old_vocab_size:
+        model.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)
+        logger.info("Resized embedding tokens from {} to {}.".format(old_vocab_size, new_vocab_size))
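
Note (not part of the patch): the sketch below shows how the vocabulary-size check and the pad_to_multiple_of=64 rounding used by the new resize_embedding_layer behave in plain Hugging Face transformers. The "gpt2" checkpoint and the added "<pad>" token are assumptions chosen purely for illustration.

    # Minimal standalone sketch of the same resize logic (assumed checkpoint: gpt2).
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Adding a token grows the tokenizer beyond the model's embedding matrix.
    tokenizer.add_special_tokens({"pad_token": "<pad>"})

    old_vocab_size = model.get_input_embeddings().weight.size(0)  # 50257 for gpt2
    new_vocab_size = len(tokenizer)                               # 50258 after adding <pad>
    if new_vocab_size != old_vocab_size:
        # pad_to_multiple_of=64 rounds the number of embedding rows up to a
        # multiple of 64 (50304 here), a shape that is friendlier to tensor cores.
        model.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)

    print(old_vocab_size, len(tokenizer), model.get_input_embeddings().weight.size(0))

Because of that rounding, the resized embedding matrix may hold a few more rows than len(tokenizer); the patch's log message reports the tokenizer size, while the actual matrix size is the next multiple of 64.
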