diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py
index 2434016e..3df33c70 100644
--- a/src/llmtuner/model/loader.py
+++ b/src/llmtuner/model/loader.py
@@ -28,7 +28,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.model.adapter import init_adapter
-from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -185,6 +185,9 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
 
+    # Resize token embeddings
+    resize_embedding_layer(model, tokenizer)
+
     # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
diff --git a/src/llmtuner/model/utils.py b/src/llmtuner/model/utils.py
index 42bef35b..b52582f5 100644
--- a/src/llmtuner/model/utils.py
+++ b/src/llmtuner/model/utils.py
@@ -11,6 +11,7 @@ from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
+    from transformers.tokenization_utils import PreTrainedTokenizer
 
     from llmtuner.hparams import DataArguments
 
@@ -181,3 +182,14 @@ def prepare_model_for_training(
             output_layer.register_forward_hook(fp32_forward_post_hook)
 
     return model
+
+
+def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
+    r"""
+    Resize token embeddings.
+    """
+    old_vocab_size = model.get_input_embeddings().weight.size(0)
+    new_vocab_size = len(tokenizer)
+    if new_vocab_size != old_vocab_size:
+        model.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)
+        logger.info("Resized embedding tokens from {} to {}.".format(old_vocab_size, new_vocab_size))
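
For context only, not part of the patch: a minimal sketch of the behavior the new resize_embedding_layer helper relies on, using the standard Hugging Face transformers API. The checkpoint name and the added pad token below are illustrative assumptions, and the snippet assumes a transformers version that supports the pad_to_multiple_of argument of resize_token_embeddings.

# Minimal sketch (assumptions: "my-model" and the "<pad>" token are placeholders,
# not values used by the patch).
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("my-model")
tokenizer = AutoTokenizer.from_pretrained("my-model")

# Adding special tokens grows the tokenizer, so len(tokenizer) can exceed the
# first dimension of the model's input embedding matrix.
tokenizer.add_special_tokens({"pad_token": "<pad>"})

old_vocab_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) != old_vocab_size:
    # pad_to_multiple_of=64 rounds the new embedding size up to a multiple of 64,
    # which keeps the matrix shape friendly to tensor cores; the resulting size
    # may therefore be slightly larger than len(tokenizer).
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)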