[fix] correctly place new token embeddings when embedding is padded (#10547)

This commit is contained in:
Ximing Xing
2026-06-05 10:47:51 +08:00
committed by GitHub
parent 8a4f6a3da5
commit 0b7aaf8f6a
3 changed files with 341 additions and 41 deletions

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from collections.abc import Iterable
from contextlib import nullcontext
from typing import TYPE_CHECKING, Optional
@@ -29,7 +30,81 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
def get_embedding_vocab_size(model: "PreTrainedModel") -> int:
    r"""Return the number of rows of the model's input embedding weight.

    When DeepSpeed ZeRO-3 is enabled, the embedding weight is gathered with
    ``deepspeed.zero.GatheredParameters`` and its size is read inside that context.

    Args:
        model: The model whose input embedding layer is inspected.

    Returns:
        ``weight.size(0)`` of the layer returned by ``model.get_input_embeddings()``.
    """
    embed_layer = model.get_input_embeddings()
    if not is_deepspeed_zero3_enabled():
        return embed_layer.weight.size(0)

    # Imported lazily so DeepSpeed is only required when ZeRO-3 is actually enabled.
    import deepspeed  # type: ignore

    with deepspeed.zero.GatheredParameters([embed_layer.weight]):
        num_rows = embed_layer.weight.size(0)

    return num_rows
def _resolve_new_token_ids(
    new_tokens: Optional[Iterable[str]],
    tokenizer: "PreTrainedTokenizer",
    embed_size: int,
) -> Optional[list[int]]:
    r"""Resolve the explicit embedding-row IDs of the newly added tokens.

    Relying on ``embed_weight[-num_new_tokens:]`` to locate new tokens is unsafe when
    the model embedding was already padded beyond the tokenizer vocab (e.g. Qwen2.5-VL
    has vocab 151665 but embedding 151936). In that case the appended tokens land
    inside the original padding zone and the tail slice points at the wrong rows.

    Args:
        new_tokens: Iterable of the newly added token strings. A single bare string is
            treated as one token rather than being iterated character by character.
            A dict is iterated by its keys.
        tokenizer: The tokenizer instance.
        embed_size: Current embedding size (exclusive upper bound for valid token IDs).

    Returns:
        A sorted list of unique, in-range token IDs, or ``None`` when no tokens are
        given (or none of them resolve) so that callers can fall back to the
        tail-slice behaviour.
    """
    if not new_tokens:
        return None

    # ``str`` is itself an ``Iterable[str]``: without this guard a single token such as
    # "<think>" would be looked up as "<", "t", "h", ... and resolve to the wrong rows.
    if isinstance(new_tokens, str):
        new_tokens = [new_tokens]

    unk_token_id = getattr(tokenizer, "unk_token_id", None)
    token_ids: set[int] = set()
    for token_str in new_tokens:
        token_id = tokenizer.convert_tokens_to_ids(token_str)
        # Skip tokens the tokenizer does not know (None / unk) and IDs that do not
        # address a row of the current embedding matrix.
        if token_id is None or token_id == unk_token_id or not (0 <= token_id < embed_size):
            logger.warning_rank0(f"Token '{token_str}' not found or out of range, skipping during init.")
            continue

        token_ids.add(token_id)

    # An empty result is reported as None so callers only need a single falsy check.
    return sorted(token_ids) or None
def _existing_embeddings(
    embed_weight: "torch.Tensor", num_new_tokens: int, new_token_ids: Optional[list[int]]
) -> "torch.Tensor":
    """Select the rows that count as 'existing' embeddings (the init baseline).

    Explicit new-token IDs win when given: exactly those rows are excluded, which stays
    correct when the embedding matrix is padded. Without IDs, the last
    ``num_new_tokens`` rows are dropped; if that count is not positive the full
    matrix is returned unchanged.
    """
    if not new_token_ids:
        # No explicit IDs: fall back to the tail-slice convention.
        return embed_weight[:-num_new_tokens] if num_new_tokens > 0 else embed_weight

    target_device = embed_weight.device
    excluded_rows = torch.as_tensor(new_token_ids, device=target_device, dtype=torch.long)
    keep_row = torch.ones(embed_weight.size(0), dtype=torch.bool, device=target_device)
    keep_row[excluded_rows] = False
    return embed_weight[keep_row]
def _noisy_mean_initialization(
embed_weight: "torch.Tensor", num_new_tokens: int, token_ids: Optional[list[int]] = None
) -> None:
"""Initialize new token embeddings with mean + Gaussian noise.
This is the default initialization method used by LlamaFactory.
@@ -37,12 +112,23 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added at the end of the embedding matrix
token_ids: Explicit token IDs to initialize. When provided, these exact rows are
written (robust to padding). When ``None``, falls back to the last
``num_new_tokens`` rows.
"""
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
avg_weight = _existing_embeddings(embed_weight, num_new_tokens, token_ids).mean(dim=0, keepdim=True)
if token_ids:
noise_weight = torch.empty(
len(token_ids), embedding_dim, device=embed_weight.device, dtype=embed_weight.dtype
)
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[token_ids] = avg_weight + noise_weight
else:
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def _description_based_initialization(
@@ -51,6 +137,7 @@ def _description_based_initialization(
descriptions: dict[str, str],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
new_token_ids: Optional[list[int]] = None,
add_noise: bool = False,
) -> None:
"""Initialize new token embeddings based on textual descriptions.
@@ -61,6 +148,9 @@ def _description_based_initialization(
3. Averages them to initialize the new token's embedding
4. Optionally adds Gaussian noise
New tokens are placed by their resolved token ID rather than by tail slicing,
so the initialization is correct even when the embedding matrix was padded.
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added
@@ -68,6 +158,8 @@ def _description_based_initialization(
e.g., {"<think>": "A token representing reasoning process"}
tokenizer: The tokenizer instance
model: The model instance (used to get input embeddings)
new_token_ids: IDs of all newly added tokens. Used to exclude not-yet-initialized
rows when averaging description-token embeddings (robust to embedding padding).
add_noise: Whether to add Gaussian noise to the initialization
Example:
@@ -77,38 +169,55 @@ def _description_based_initialization(
}
"""
embedding_dim = embed_weight.size(1)
vocab_size = embed_weight.size(0)
unk_token_id = getattr(tokenizer, "unk_token_id", None)
device = embed_weight.device
# The set of rows that are NOT yet initialized (the newly added tokens). Description
# tokens that fall into this set must be excluded, otherwise we would average garbage.
# `num_new_tokens` (the padded resize delta) is NOT a reliable boundary, so rely on
# the explicit IDs, falling back to resolving them from the description keys.
if new_token_ids is None:
new_token_ids = _resolve_new_token_ids(descriptions.keys(), tokenizer, vocab_size)
new_id_set = set(new_token_ids or [])
fallback_embedding = _existing_embeddings(embed_weight, num_new_tokens, new_token_ids).mean(dim=0)
for token_str, desc in descriptions.items():
# Resolve token ID for correct placement (robust to embedding padding)
token_id = tokenizer.convert_tokens_to_ids(token_str)
if token_id is None or token_id == unk_token_id or not (0 <= token_id < vocab_size):
logger.warning_rank0(f"desc_init: token '{token_str}' not found or out of range, skipping.")
continue
for i, desc in enumerate(descriptions.values()):
# Tokenize description text
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
token_ids = tokens["input_ids"][0]
# Move to the same device as embed_weight
device = embed_weight.device
token_ids = token_ids.to(device)
token_ids = tokens["input_ids"][0].tolist()
# Filter out new tokens (they don't have valid embeddings yet)
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
# Keep only description tokens that already have a meaningful embedding.
valid_token_ids = [tid for tid in token_ids if tid not in new_id_set and 0 <= tid < vocab_size]
if len(valid_token_ids) == 0:
# Fallback: use mean of all existing embeddings
logger.warning_rank0(
f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
f"Description for token '{token_str}' contains no valid tokens. "
"Using mean of existing embeddings."
)
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
base_embedding = fallback_embedding
else:
# Get embeddings of description tokens and average them
token_embeds = model.get_input_embeddings()(valid_token_ids)
valid_ids_tensor = torch.as_tensor(valid_token_ids, device=device, dtype=torch.long)
token_embeds = model.get_input_embeddings()(valid_ids_tensor)
base_embedding = token_embeds.mean(dim=0)
# Add noise if requested (ensure correct device and dtype)
if add_noise:
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
embed_weight[-num_new_tokens + i] = base_embedding + noise
embed_weight[token_id] = base_embedding + noise
else:
embed_weight[-num_new_tokens + i] = base_embedding
embed_weight[token_id] = base_embedding
def _initialize_embeddings(
@@ -118,6 +227,7 @@ def _initialize_embeddings(
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
new_token_ids: Optional[list[int]] = None,
) -> None:
"""Single source of truth for embedding initialization.
@@ -130,16 +240,18 @@ def _initialize_embeddings(
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
new_token_ids: Explicit IDs of the newly added tokens (robust to embedding padding).
When ``None``, the init helpers fall back to the last ``num_new_tokens`` rows.
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=True
)
else:
if init_method != "noise_init":
@@ -147,20 +259,28 @@ def _initialize_embeddings(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
_noisy_mean_initialization(embed_weight, num_new_tokens, token_ids=new_token_ids)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_tokens: Optional[Iterable[str]] = None,
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
r"""Resize token embeddings (when needed) and initialize the newly added tokens.
Resizing and initialization are decoupled: even when the tokenizer vocab fits inside
the model's existing (padded) embedding matrix and no resize is triggered, the newly
added tokens still occupy uninitialized rows and must be initialized. We therefore
resolve the explicit row IDs of ``new_tokens`` and always initialize those rows.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_tokens: Iterable of the newly added token strings. Used to locate the exact
embedding rows to initialize, which is robust to pre-existing embedding padding.
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
@@ -175,44 +295,70 @@ def resize_embedding_layer(
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
current_embedding_size = model.get_input_embeddings().weight.size(0)
current_embedding_size = get_embedding_vocab_size(model)
needs_resize = len(tokenizer) > current_embedding_size
if len(tokenizer) > current_embedding_size:
if needs_resize:
if getattr(model, "quantization_method", None):
raise ValueError("Cannot resize embedding layers of a quantized model.")
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
raise ValueError("Current model does not support resizing embedding layers.")
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
# mean_resizing=False preserves the original embedding distribution exactly.
# HuggingFace's default mean_resizing=True re-samples new rows from the mean/covariance
# of existing embeddings, which conflicts with our explicit initialization below.
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64, mean_resizing=False)
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
# Resolve the exact rows of the new tokens. This works whether or not a resize was
# triggered (e.g. tokens added into a model's pre-existing padding zone).
new_token_ids = _resolve_new_token_ids(new_tokens, tokenizer, new_embedding_size)
if num_new_tokens <= 0 and not new_token_ids:
return
if needs_resize:
logger.info_rank0(
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
else:
logger.info_rank0(
f"No resize needed (vocab fits in padded embedding {new_embedding_size}); "
f"initializing {len(new_token_ids or [])} new token(s) in place."
)
# Initialize input embeddings
# Initialize input embeddings
_initialize_embeddings(
model.get_input_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_input_embeddings().weight.data,
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
new_token_ids=new_token_ids,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
if needs_resize:
model.config.vocab_size = new_embedding_size
# Also update the nested text_config for VL models (e.g., Qwen2.5-VL, LLaVA),
# otherwise config.vocab_size and config.text_config.vocab_size become inconsistent.
if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "vocab_size"):
model.config.text_config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@@ -457,9 +457,14 @@ def patch_model(
prepare_valuehead_model(model)
if model_args.resize_vocab:
# Pass the explicit list of newly added tokens so their exact embedding rows can be
# located and initialized, even when they land in a model's pre-existing padding zone.
new_tokens = (model_args.add_tokens or []) + (model_args.add_special_tokens or [])
resize_embedding_layer(
model,
tokenizer,
new_tokens=new_tokens or None,
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
init_special_tokens=model_args.init_special_tokens,
)

View File

@@ -0,0 +1,149 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.model.model_utils.embedding import (
_description_based_initialization,
_existing_embeddings,
_noisy_mean_initialization,
_resolve_new_token_ids,
)
class _StubTokenizer:
"""Minimal tokenizer stub mapping token strings to fixed IDs."""
unk_token_id = 0
def __init__(self, mapping: dict[str, int], desc_ids: list[int] | None = None):
self._mapping = mapping
self._desc_ids = desc_ids or []
def convert_tokens_to_ids(self, token: str) -> int:
return self._mapping.get(token, self.unk_token_id)
def __call__(self, desc, return_tensors=None, add_special_tokens=False):
return {"input_ids": torch.tensor([self._desc_ids], dtype=torch.long)}
class _StubModel:
    """Model double exposing a frozen embedding lookup via ``get_input_embeddings()``."""

    def __init__(self, embed_weight: "torch.Tensor"):
        # Clone so the lookup table is independent of later in-place edits to the caller's tensor.
        frozen_copy = embed_weight.clone()
        self._emb = torch.nn.Embedding.from_pretrained(frozen_copy, freeze=True)

    def get_input_embeddings(self):
        """Return the wrapped ``torch.nn.Embedding`` module."""
        return self._emb
def test_resolve_new_token_ids_returns_none_without_config():
    """``None`` and an empty list both resolve to ``None`` (nothing to initialize)."""
    empty_tokenizer = _StubTokenizer({})
    for no_tokens in (None, []):
        assert _resolve_new_token_ids(no_tokens, empty_tokenizer, embed_size=100) is None
def test_resolve_new_token_ids_filters_invalid_and_dedups():
    """Unknown and out-of-range tokens are dropped; the result is sorted and unique."""
    # "<a>" and "<b>" are valid, "<unk_like>" maps to unk_token_id (skipped),
    # and "<oob>" lies outside the embedding range (skipped).
    id_table = {"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5}
    tokenizer = _StubTokenizer(id_table)

    # Duplicated, unsorted input must come back as sorted unique in-range IDs.
    requested = ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"]
    resolved = _resolve_new_token_ids(requested, tokenizer, embed_size=100)
    assert resolved == [5, 10]

    # A dict is iterated by key, which keeps a descriptions config usable as input.
    from_config = _resolve_new_token_ids({"<a>": "desc"}, tokenizer, embed_size=100)
    assert from_config == [10]
def test_existing_embeddings_excludes_new_token_ids():
    """Check which rows ``_existing_embeddings`` keeps in each selection mode."""
    embed_weight = torch.arange(10 * 2, dtype=torch.float32).reshape(10, 2)

    # Explicit IDs take precedence over num_new_tokens and drop exactly those rows.
    existing = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=[2, 5])
    assert existing.size(0) == 8
    # The row count alone would also pass if the wrong two rows were dropped, so pin
    # the surviving rows (every row is unique thanks to the arange fill).
    kept_rows = [0, 1, 3, 4, 6, 7, 8, 9]
    assert torch.equal(existing, embed_weight[kept_rows])

    # Tail fallback when no explicit IDs are available.
    tail = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=None)
    assert torch.allclose(tail, embed_weight[:-3])

    # No resize and no IDs -> every row counts as existing.
    everything = _existing_embeddings(embed_weight, num_new_tokens=0, new_token_ids=None)
    assert torch.allclose(everything, embed_weight)
def test_noisy_mean_initialization_with_token_ids_targets_exact_rows():
    """Explicit token IDs must be the rows that get written, even inside the padding zone."""
    torch.manual_seed(0)
    vocab_size, embedding_dim = 20, 8
    zero_row = torch.zeros(embedding_dim)

    # Existing rows hold a constant so their mean is well-defined and non-zero.
    embed_weight = torch.zeros(vocab_size, embedding_dim)
    embed_weight[:16] = 1.0

    # num_new_tokens is the resize delta (4 padded rows), while the real new tokens
    # sit at IDs 16 and 17, which a tail slice of that size would not match exactly.
    target_ids = [16, 17]
    _noisy_mean_initialization(embed_weight, num_new_tokens=4, token_ids=target_ids)

    # Targeted rows end up near the mean (~1.0) instead of staying at zero.
    for row_id in target_ids:
        row = embed_weight[row_id]
        assert not torch.allclose(row, zero_row)
        assert abs(row.mean().item() - 1.0) < 0.5

    # Padding rows that were not targeted (18, 19) must remain zero.
    for padding_id in (18, 19):
        assert torch.allclose(embed_weight[padding_id], zero_row)
def test_noisy_mean_initialization_tail_fallback():
    """With ``token_ids=None`` the last ``num_new_tokens`` rows are the ones initialized."""
    torch.manual_seed(0)
    vocab_size, embedding_dim = 12, 8
    zero_row = torch.zeros(embedding_dim)

    embed_weight = torch.zeros(vocab_size, embedding_dim)
    embed_weight[:10] = 1.0

    _noisy_mean_initialization(embed_weight, num_new_tokens=2, token_ids=None)

    # The two trailing rows were written ...
    for tail_index in (-1, -2):
        assert not torch.allclose(embed_weight[tail_index], zero_row)
    # ... while an earlier row keeps its original value.
    assert torch.allclose(embed_weight[0], torch.ones(embedding_dim))
def test_description_init_excludes_new_token_ids_from_average():
    """New (still uninitialized) tokens must not contribute to a description average.

    Mirrors the padding-zone bug: id 17 is itself a new token, so the semantic average
    for id 16 may only use the valid existing description token (id 5).
    """
    vocab_size, embedding_dim = 20, 4
    embed_weight = torch.zeros(vocab_size, embedding_dim)
    embed_weight[5] = 3.0  # the single description token with a meaningful embedding

    # The stub tokenizes every description to [5 (existing), 17 (new -> must be skipped)].
    tokenizer = _StubTokenizer({"<x>": 16}, desc_ids=[5, 17])
    model = _StubModel(embed_weight)
    descriptions = {"<x>": "ignored, ids come from the stub"}

    _description_based_initialization(
        embed_weight,
        num_new_tokens=4,
        descriptions=descriptions,
        tokenizer=tokenizer,
        model=model,
        new_token_ids=[16, 17],
        add_noise=False,
    )

    # Row 16 must match id 5 alone (3.0), not the average over ids 5 and 17 (1.5).
    expected_row = torch.full((embedding_dim,), 3.0)
    assert torch.allclose(embed_weight[16], expected_row)
# Allow running this test module directly as a script: pytest is invoked on this file only.
if __name__ == "__main__":
    # Imported here so pytest is only needed for direct script execution.
    import pytest
    pytest.main([__file__])