support resize embed for zero3

Former-commit-id: a5f6a7f4fb057511428011c37422c535f31b79d2
This commit is contained in:
liuzc 2024-01-16 15:16:20 +08:00
parent 7e16d27fca
commit 61bc5bd0dd

View File

@ -5,6 +5,7 @@ import random
from types import MethodType from types import MethodType
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from datasets import load_dataset from datasets import load_dataset
from contextlib import nullcontext
from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
from transformers.integrations import is_deepspeed_zero3_enabled from transformers.integrations import is_deepspeed_zero3_enabled
@ -28,7 +29,7 @@ SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int): def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
embedding_dim = embed_weight.size(1) embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(avg_weight[-num_new_tokens:]) noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight embed_weight[-num_new_tokens:] = avg_weight + noise_weight
@ -37,6 +38,11 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
r""" r"""
Resize token embeddings. Resize token embeddings.
""" """
if is_deepspeed_zero3_enabled():
import deepspeed
with deepspeed.zero.GatheredParameters(model.get_input_embeddings().weight, modifier_rank=None):
current_embedding_size = model.get_input_embeddings().weight.size(0)
else:
current_embedding_size = model.get_input_embeddings().weight.size(0) current_embedding_size = model.get_input_embeddings().weight.size(0)
if len(tokenizer) > current_embedding_size: if len(tokenizer) > current_embedding_size:
if not isinstance(model.get_output_embeddings(), torch.nn.Linear): if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
@ -44,6 +50,15 @@ def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToke
return return
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
if is_deepspeed_zero3_enabled():
import deepspeed
params = [model.get_input_embeddings().weight]
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
params.append(model.get_output_embeddings().weight)
context = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context = nullcontext()
with context:
new_embedding_size = model.get_input_embeddings().weight.size(0) new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
@ -264,9 +279,6 @@ def patch_model(
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
if model_args.resize_vocab: if model_args.resize_vocab:
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with vocab resizing.")
_resize_embedding_layer(model, tokenizer) _resize_embedding_layer(model, tokenizer)
if is_trainable: if is_trainable: