Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 04:32:50 +08:00)

Commit 9ead5a2d21 (parent 5819eb7121): support resize embeddings #1786
Former-commit-id: 64744dde89ccb9a24a46985a99151ad2dde03919
```diff
@@ -28,7 +28,7 @@ from llmtuner.extras.packages import is_flash_attn2_available
 from llmtuner.extras.patches import llama_patch as LlamaPatches
 from llmtuner.hparams import FinetuningArguments
 from llmtuner.model.adapter import init_adapter
-from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training
+from llmtuner.model.utils import load_valuehead_params, prepare_model_for_training, resize_embedding_layer
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -185,6 +185,9 @@ def load_model_and_tokenizer(
         **config_kwargs
     )
 
+    # Resize token embeddings
+    resize_embedding_layer(model, tokenizer)
+
     # Disable custom generate method (for Qwen and Baichuan2)
     if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
```
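In the loader, the new call runs right after the model is instantiated: if the tokenizer carries more tokens than the checkpoint's vocabulary (for example after special tokens were added), the embedding matrix is grown to match before the custom `generate` handling. As a rough illustration only (plain Transformers API, with "gpt2" and a "<pad>" token as placeholder choices, not LLaMA-Factory's actual loader), the pattern this call automates looks like:

```python
# Rough sketch, not LLaMA-Factory's loader. "gpt2" and "<pad>" are placeholder choices.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>"})  # tokenizer now holds one extra token

model = AutoModelForCausalLM.from_pretrained("gpt2")
old_vocab_size = model.get_input_embeddings().weight.size(0)

if len(tokenizer) != old_vocab_size:
    # Grow the input (and tied output) embeddings so every tokenizer id has a row.
    model.resize_token_embeddings(len(tokenizer))
    new_vocab_size = model.get_input_embeddings().weight.size(0)
    print(f"Resized embeddings from {old_vocab_size} to {new_vocab_size}")
```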
```diff
@@ -11,6 +11,7 @@ from llmtuner.hparams import ModelArguments, FinetuningArguments
 
 if TYPE_CHECKING:
     from transformers.modeling_utils import PreTrainedModel
+    from transformers.tokenization_utils import PreTrainedTokenizer
     from llmtuner.hparams import DataArguments
 
 
@@ -181,3 +182,14 @@ def prepare_model_for_training(
             output_layer.register_forward_hook(fp32_forward_post_hook)
 
     return model
+
+
+def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
+    r"""
+    Resize token embeddings.
+    """
+    old_vocab_size = model.get_input_embeddings().weight.size(0)
+    new_vocab_size = len(tokenizer)
+    if new_vocab_size != old_vocab_size:
+        model.resize_token_embeddings(new_vocab_size, pad_to_multiple_of=64)
+        logger.info("Resized embedding tokens from {} to {}.".format(old_vocab_size, new_vocab_size))
```
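The helper itself lands in the module the loader imports as `llmtuner.model.utils`. Note that `pad_to_multiple_of=64` makes `resize_token_embeddings` round the new matrix up to the next multiple of 64 rows (a shape GPU kernels generally handle more efficiently), so the resulting embedding count can end up slightly above `len(tokenizer)`, while the log line reports the tokenizer length. A small sketch of the rounding behaviour, using a tiny randomly initialized GPT-2 config with arbitrary sizes so nothing is downloaded:

```python
# Illustration of pad_to_multiple_of=64; the config sizes here are arbitrary.
from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config(vocab_size=32000, n_embd=64, n_layer=1, n_head=1))

# Ask for 32005 rows; with pad_to_multiple_of=64 the matrix is padded to 32064 (= 501 * 64).
model.resize_token_embeddings(32005, pad_to_multiple_of=64)
print(model.get_input_embeddings().weight.size(0))  # 32064
```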