[model] adds semantic initialization support for special tokens (#9267)

Co-authored-by: ximingxing <ximingxing@tencent.com>
Ximing Xing
2025-10-14 17:00:48 +08:00
committed by GitHub
parent 3dbca4b533
commit c867e28093
4 changed files with 264 additions and 7 deletions

View File

@@ -14,7 +14,7 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
"""Initialize new token embeddings with mean + Gaussian noise.
This is the default initialization method used by LlamaFactory.
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added at the end of the embedding matrix
"""
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""Resize token embeddings."""
def _description_based_initialization(
embed_weight: "torch.Tensor",
num_new_tokens: int,
descriptions: dict[str, str],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
add_noise: bool = False,
) -> None:
"""Initialize new token embeddings based on textual descriptions.
For each new token, this function:
1. Tokenizes its description text
2. Gets embeddings of the description tokens
3. Averages them to initialize the new token's embedding
4. Optionally adds Gaussian noise
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added
descriptions: Dict mapping token string to its description text
e.g., {"<think>": "A token representing reasoning process"}
tokenizer: The tokenizer instance
model: The model instance (used to get input embeddings)
add_noise: Whether to add Gaussian noise to the initialization
Example:
descriptions = {
"<|START_OF_SVG|>": "Marks the beginning of an SVG document",
"<|END_OF_SVG|>": "Marks the end of an SVG document"
}
"""
embedding_dim = embed_weight.size(1)
for i, desc in enumerate(descriptions.values()):
# Tokenize description text
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
token_ids = tokens["input_ids"][0]
# Move to the same device as embed_weight
device = embed_weight.device
token_ids = token_ids.to(device)
# Filter out new tokens (they don't have valid embeddings yet)
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
if len(valid_token_ids) == 0:
# Fallback: use mean of all existing embeddings
logger.warning_rank0(
f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
"Using mean of existing embeddings."
)
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
else:
# Get embeddings of description tokens and average them
token_embeds = model.get_input_embeddings()(valid_token_ids)
base_embedding = token_embeds.mean(dim=0)
# Add noise if requested (ensure correct device and dtype)
if add_noise:
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
embed_weight[-num_new_tokens + i] = base_embedding + noise
else:
embed_weight[-num_new_tokens + i] = base_embedding
def _initialize_embeddings(
embed_weight: "torch.Tensor",
num_new_tokens: int,
init_method: str,
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
) -> None:
"""Single source of truth for embedding initialization.
This function selects the appropriate initialization method and applies it.
Args:
embed_weight: The embedding weight matrix to initialize
num_new_tokens: Number of new tokens added
init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
)
else:
if init_method != "noise_init":
logger.warning_rank0(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
@@ -64,8 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info_rank0(
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
# Initialize input embeddings
_initialize_embeddings(
model.get_input_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
model.config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@@ -175,7 +175,12 @@ def patch_model(
prepare_valuehead_model(model)
if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer)
resize_embedding_layer(
model,
tokenizer,
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
init_special_tokens=model_args.init_special_tokens,
)
if is_trainable:
if getattr(model.config, "model_type", None) == "gemma3n":
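As a closing illustration, a minimal end-to-end sketch of how the new call path can be exercised outside of patch_model. The import path and the base model name are assumptions; in normal use these arguments come from model_args (init_special_tokens and _special_token_descriptions in the hunk above), not from a hand-written dict.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed import path; adjust to wherever resize_embedding_layer lives in this repo.
from llamafactory.model.model_utils.embedding import resize_embedding_layer

descriptions = {
    "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
    "<|END_OF_SVG|>": "Marks the end of an SVG document",
}

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": list(descriptions)})

# Resizes the embedding matrix and initializes the two new rows from their descriptions.
resize_embedding_layer(
    model,
    tokenizer,
    new_special_tokens_config=descriptions,
    init_special_tokens="desc_init_w_noise",  # or "noise_init" / "desc_init"
)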