support resize embeddings #1786

This commit is contained in:
hiyouga
2023-12-11 17:50:02 +08:00
parent 9ce1b0e2f2
commit 64744dde89
2 changed files with 16 additions and 1 deletion

View File

@@ -28,7 +28,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.hparams import FinetuningArguments
from llmtuner.model.adapter import init_adapter
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -185,6 +185,9 @@ def load_model_and_tokenizer(
**config_kwargs
)
# Resize token embeddings
resize_embedding_layer(model, tokenizer)
# Disable custom generate method (for Qwen and Baichuan2)
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)