From 0b7aaf8f6a624bd89a01a155d4265ec861cbdf38 Mon Sep 17 00:00:00 2001 From: Ximing Xing Date: Fri, 5 Jun 2026 10:47:51 +0800 Subject: [PATCH] [fix] correctly place new token embeddings when embedding is padded (#10547) --- .../model/model_utils/embedding.py | 228 ++++++++++++++---- src/llamafactory/model/patcher.py | 5 + tests/model/model_utils/test_embedding.py | 149 ++++++++++++ 3 files changed, 341 insertions(+), 41 deletions(-) create mode 100644 tests/model/model_utils/test_embedding.py diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py index b503f3b97..058d41be9 100644 --- a/src/llamafactory/model/model_utils/embedding.py +++ b/src/llamafactory/model/model_utils/embedding.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from collections.abc import Iterable from contextlib import nullcontext from typing import TYPE_CHECKING, Optional @@ -29,7 +30,81 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: +def get_embedding_vocab_size(model: "PreTrainedModel") -> int: + r"""Get the vocab size from the input embedding layer. + + Handles DeepSpeed ZeRO-3 parameter sharding by gathering the embedding weight + before reading its size. + """ + embedding = model.get_input_embeddings() + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore + + with deepspeed.zero.GatheredParameters([embedding.weight]): + return embedding.weight.size(0) + + return embedding.weight.size(0) + + +def _resolve_new_token_ids( + new_tokens: Optional[Iterable[str]], + tokenizer: "PreTrainedTokenizer", + embed_size: int, +) -> Optional[list[int]]: + r"""Resolve the explicit embedding-row IDs of the newly added tokens. + + Relying on ``embed_weight[-num_new_tokens:]`` to locate new tokens is unsafe when + the model embedding was already padded beyond the tokenizer vocab (e.g. Qwen2.5-VL + has vocab 151665 but embedding 151936). In that case the appended tokens land + inside the original padding zone and the tail slice points at the wrong rows. + + Args: + new_tokens: Iterable of the newly added token strings. + tokenizer: The tokenizer instance. + embed_size: Current embedding size (upper bound for valid token IDs). + + Returns: + A sorted list of unique, in-range token IDs, or ``None`` when no tokens are + given so that callers can fall back to the tail-slice behaviour. + """ + if not new_tokens: + return None + + unk_token_id = getattr(tokenizer, "unk_token_id", None) + token_ids: set[int] = set() + for token_str in new_tokens: + token_id = tokenizer.convert_tokens_to_ids(token_str) + if token_id is None or token_id == unk_token_id or not (0 <= token_id < embed_size): + logger.warning_rank0(f"Token '{token_str}' not found or out of range, skipping during init.") + continue + + token_ids.add(token_id) + + return sorted(token_ids) or None + + +def _existing_embeddings( + embed_weight: "torch.Tensor", num_new_tokens: int, new_token_ids: Optional[list[int]] +) -> "torch.Tensor": + """Return the rows treated as 'existing' embeddings used as the init baseline. + + Prefers excluding the explicit new-token rows (robust to padding). Falls back to + dropping the last ``num_new_tokens`` rows when no explicit IDs are available. + """ + if new_token_ids: + mask = torch.ones(embed_weight.size(0), dtype=torch.bool, device=embed_weight.device) + mask[torch.as_tensor(new_token_ids, device=embed_weight.device, dtype=torch.long)] = False + return embed_weight[mask] + + if num_new_tokens > 0: + return embed_weight[:-num_new_tokens] + + return embed_weight + + +def _noisy_mean_initialization( + embed_weight: "torch.Tensor", num_new_tokens: int, token_ids: Optional[list[int]] = None +) -> None: """Initialize new token embeddings with mean + Gaussian noise. This is the default initialization method used by LlamaFactory. @@ -37,12 +112,23 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int Args: embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim]) num_new_tokens: Number of new tokens added at the end of the embedding matrix + token_ids: Explicit token IDs to initialize. When provided, these exact rows are + written (robust to padding). When ``None``, falls back to the last + ``num_new_tokens`` rows. """ embedding_dim = embed_weight.size(1) - avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) - noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) - noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) - embed_weight[-num_new_tokens:] = avg_weight + noise_weight + avg_weight = _existing_embeddings(embed_weight, num_new_tokens, token_ids).mean(dim=0, keepdim=True) + + if token_ids: + noise_weight = torch.empty( + len(token_ids), embedding_dim, device=embed_weight.device, dtype=embed_weight.dtype + ) + noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) + embed_weight[token_ids] = avg_weight + noise_weight + else: + noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) + noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) + embed_weight[-num_new_tokens:] = avg_weight + noise_weight def _description_based_initialization( @@ -51,6 +137,7 @@ def _description_based_initialization( descriptions: dict[str, str], tokenizer: "PreTrainedTokenizer", model: "PreTrainedModel", + new_token_ids: Optional[list[int]] = None, add_noise: bool = False, ) -> None: """Initialize new token embeddings based on textual descriptions. @@ -61,6 +148,9 @@ def _description_based_initialization( 3. Averages them to initialize the new token's embedding 4. Optionally adds Gaussian noise + New tokens are placed by their resolved token ID rather than by tail slicing, + so the initialization is correct even when the embedding matrix was padded. + Args: embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim]) num_new_tokens: Number of new tokens added @@ -68,6 +158,8 @@ def _description_based_initialization( e.g., {"": "A token representing reasoning process"} tokenizer: The tokenizer instance model: The model instance (used to get input embeddings) + new_token_ids: IDs of all newly added tokens. Used to exclude not-yet-initialized + rows when averaging description-token embeddings (robust to embedding padding). add_noise: Whether to add Gaussian noise to the initialization Example: @@ -77,38 +169,55 @@ def _description_based_initialization( } """ embedding_dim = embed_weight.size(1) + vocab_size = embed_weight.size(0) + unk_token_id = getattr(tokenizer, "unk_token_id", None) + device = embed_weight.device + + # The set of rows that are NOT yet initialized (the newly added tokens). Description + # tokens that fall into this set must be excluded, otherwise we would average garbage. + # `num_new_tokens` (the padded resize delta) is NOT a reliable boundary, so rely on + # the explicit IDs, falling back to resolving them from the description keys. + if new_token_ids is None: + new_token_ids = _resolve_new_token_ids(descriptions.keys(), tokenizer, vocab_size) + + new_id_set = set(new_token_ids or []) + fallback_embedding = _existing_embeddings(embed_weight, num_new_tokens, new_token_ids).mean(dim=0) + + for token_str, desc in descriptions.items(): + # Resolve token ID for correct placement (robust to embedding padding) + token_id = tokenizer.convert_tokens_to_ids(token_str) + if token_id is None or token_id == unk_token_id or not (0 <= token_id < vocab_size): + logger.warning_rank0(f"desc_init: token '{token_str}' not found or out of range, skipping.") + continue - for i, desc in enumerate(descriptions.values()): # Tokenize description text tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False) with torch.no_grad(): - token_ids = tokens["input_ids"][0] - # Move to the same device as embed_weight - device = embed_weight.device - token_ids = token_ids.to(device) + token_ids = tokens["input_ids"][0].tolist() - # Filter out new tokens (they don't have valid embeddings yet) - valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)] + # Keep only description tokens that already have a meaningful embedding. + valid_token_ids = [tid for tid in token_ids if tid not in new_id_set and 0 <= tid < vocab_size] if len(valid_token_ids) == 0: # Fallback: use mean of all existing embeddings logger.warning_rank0( - f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. " + f"Description for token '{token_str}' contains no valid tokens. " "Using mean of existing embeddings." ) - base_embedding = embed_weight[:-num_new_tokens].mean(dim=0) + base_embedding = fallback_embedding else: # Get embeddings of description tokens and average them - token_embeds = model.get_input_embeddings()(valid_token_ids) + valid_ids_tensor = torch.as_tensor(valid_token_ids, device=device, dtype=torch.long) + token_embeds = model.get_input_embeddings()(valid_ids_tensor) base_embedding = token_embeds.mean(dim=0) # Add noise if requested (ensure correct device and dtype) if add_noise: noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim)) - embed_weight[-num_new_tokens + i] = base_embedding + noise + embed_weight[token_id] = base_embedding + noise else: - embed_weight[-num_new_tokens + i] = base_embedding + embed_weight[token_id] = base_embedding def _initialize_embeddings( @@ -118,6 +227,7 @@ def _initialize_embeddings( new_special_tokens_config: Optional[dict], tokenizer: "PreTrainedTokenizer", model: "PreTrainedModel", + new_token_ids: Optional[list[int]] = None, ) -> None: """Single source of truth for embedding initialization. @@ -130,16 +240,18 @@ def _initialize_embeddings( new_special_tokens_config: Config dict with token descriptions (required for desc_init methods) tokenizer: The tokenizer instance model: The model instance + new_token_ids: Explicit IDs of the newly added tokens (robust to embedding padding). + When ``None``, the init helpers fall back to the last ``num_new_tokens`` rows. """ if init_method == "desc_init" and new_special_tokens_config: logger.info_rank0("Using semantic initialization (desc_init) for new special tokens") _description_based_initialization( - embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False + embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=False ) elif init_method == "desc_init_w_noise" and new_special_tokens_config: logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens") _description_based_initialization( - embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True + embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=True ) else: if init_method != "noise_init": @@ -147,20 +259,28 @@ def _initialize_embeddings( f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'" ) logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens") - _noisy_mean_initialization(embed_weight, num_new_tokens) + _noisy_mean_initialization(embed_weight, num_new_tokens, token_ids=new_token_ids) def resize_embedding_layer( model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer", + new_tokens: Optional[Iterable[str]] = None, new_special_tokens_config: Optional[dict] = None, init_special_tokens: str = "noise_init", ) -> None: - r"""Resize token embeddings and initialize new tokens. + r"""Resize token embeddings (when needed) and initialize the newly added tokens. + + Resizing and initialization are decoupled: even when the tokenizer vocab fits inside + the model's existing (padded) embedding matrix and no resize is triggered, the newly + added tokens still occupy uninitialized rows and must be initialized. We therefore + resolve the explicit row IDs of ``new_tokens`` and always initialize those rows. Args: model: The model to resize tokenizer: The tokenizer (used to get target vocab size) + new_tokens: Iterable of the newly added token strings. Used to locate the exact + embedding rows to initialize, which is robust to pre-existing embedding padding. new_special_tokens_config: Optional dict with token descriptions for semantic initialization init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise') """ @@ -175,44 +295,70 @@ def resize_embedding_layer( else: context_maybe_zero3 = nullcontext() - with context_maybe_zero3: - current_embedding_size = model.get_input_embeddings().weight.size(0) + current_embedding_size = get_embedding_vocab_size(model) + needs_resize = len(tokenizer) > current_embedding_size - if len(tokenizer) > current_embedding_size: + if needs_resize: if getattr(model, "quantization_method", None): raise ValueError("Cannot resize embedding layers of a quantized model.") if not isinstance(model.get_output_embeddings(), torch.nn.Linear): raise ValueError("Current model does not support resizing embedding layers.") - model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) - with context_maybe_zero3: - new_embedding_size = model.get_input_embeddings().weight.size(0) - num_new_tokens = new_embedding_size - current_embedding_size + # mean_resizing=False preserves the original embedding distribution exactly. + # HuggingFace's default mean_resizing=True re-samples new rows from the mean/covariance + # of existing embeddings, which conflicts with our explicit initialization below. + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64, mean_resizing=False) + + with context_maybe_zero3: + new_embedding_size = model.get_input_embeddings().weight.size(0) + num_new_tokens = new_embedding_size - current_embedding_size + + # Resolve the exact rows of the new tokens. This works whether or not a resize was + # triggered (e.g. tokens added into a model's pre-existing padding zone). + new_token_ids = _resolve_new_token_ids(new_tokens, tokenizer, new_embedding_size) + + if num_new_tokens <= 0 and not new_token_ids: + return + + if needs_resize: logger.info_rank0( f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)" ) + else: + logger.info_rank0( + f"No resize needed (vocab fits in padded embedding {new_embedding_size}); " + f"initializing {len(new_token_ids or [])} new token(s) in place." + ) - # Initialize input embeddings + # Initialize input embeddings + _initialize_embeddings( + model.get_input_embeddings().weight.data, + num_new_tokens, + init_special_tokens, + new_special_tokens_config, + tokenizer, + model, + new_token_ids=new_token_ids, + ) + + # Initialize output embeddings if not tied + if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: _initialize_embeddings( - model.get_input_embeddings().weight.data, + model.get_output_embeddings().weight.data, num_new_tokens, init_special_tokens, new_special_tokens_config, tokenizer, model, + new_token_ids=new_token_ids, ) - # Initialize output embeddings if not tied - if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: - _initialize_embeddings( - model.get_output_embeddings().weight.data, - num_new_tokens, - init_special_tokens, - new_special_tokens_config, - tokenizer, - model, - ) - + if needs_resize: model.config.vocab_size = new_embedding_size + # Also update the nested text_config for VL models (e.g., Qwen2.5-VL, LLaVA), + # otherwise config.vocab_size and config.text_config.vocab_size become inconsistent. + if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "vocab_size"): + model.config.text_config.vocab_size = new_embedding_size + logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.") diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 8bec9526a..c2e0ea21f 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -457,9 +457,14 @@ def patch_model( prepare_valuehead_model(model) if model_args.resize_vocab: + # Pass the explicit list of newly added tokens so their exact embedding rows can be + # located and initialized, even when they land in a model's pre-existing padding zone. + new_tokens = (model_args.add_tokens or []) + (model_args.add_special_tokens or []) + resize_embedding_layer( model, tokenizer, + new_tokens=new_tokens or None, new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None), init_special_tokens=model_args.init_special_tokens, ) diff --git a/tests/model/model_utils/test_embedding.py b/tests/model/model_utils/test_embedding.py new file mode 100644 index 000000000..b7e01e1ad --- /dev/null +++ b/tests/model/model_utils/test_embedding.py @@ -0,0 +1,149 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from llamafactory.model.model_utils.embedding import ( + _description_based_initialization, + _existing_embeddings, + _noisy_mean_initialization, + _resolve_new_token_ids, +) + + +class _StubTokenizer: + """Minimal tokenizer stub mapping token strings to fixed IDs.""" + + unk_token_id = 0 + + def __init__(self, mapping: dict[str, int], desc_ids: list[int] | None = None): + self._mapping = mapping + self._desc_ids = desc_ids or [] + + def convert_tokens_to_ids(self, token: str) -> int: + return self._mapping.get(token, self.unk_token_id) + + def __call__(self, desc, return_tensors=None, add_special_tokens=False): + return {"input_ids": torch.tensor([self._desc_ids], dtype=torch.long)} + + +class _StubModel: + """Wraps an embedding matrix so ``get_input_embeddings()`` is a usable lookup.""" + + def __init__(self, embed_weight: "torch.Tensor"): + self._emb = torch.nn.Embedding.from_pretrained(embed_weight.clone(), freeze=True) + + def get_input_embeddings(self): + return self._emb + + +def test_resolve_new_token_ids_returns_none_without_config(): + tokenizer = _StubTokenizer({}) + assert _resolve_new_token_ids(None, tokenizer, embed_size=100) is None + assert _resolve_new_token_ids([], tokenizer, embed_size=100) is None + + +def test_resolve_new_token_ids_filters_invalid_and_dedups(): + # "" valid, "" maps to unk_token_id (skipped), "" out of range (skipped) + tokenizer = _StubTokenizer({"": 10, "": 0, "": 999, "": 5}) + # duplicates and unsorted input -> sorted unique in-range IDs + tokens = ["", "", "", "", ""] + assert _resolve_new_token_ids(tokens, tokenizer, embed_size=100) == [5, 10] + # passing a dict iterates its keys (config compatibility) + assert _resolve_new_token_ids({"": "desc"}, tokenizer, embed_size=100) == [10] + + +def test_existing_embeddings_excludes_new_token_ids(): + embed_weight = torch.arange(10 * 2, dtype=torch.float32).reshape(10, 2) + # explicit ids take precedence and drop exactly those rows + existing = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=[2, 5]) + assert existing.size(0) == 8 + # tail fallback when no explicit ids + tail = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=None) + assert torch.allclose(tail, embed_weight[:-3]) + # no resize and no ids -> use everything + everything = _existing_embeddings(embed_weight, num_new_tokens=0, new_token_ids=None) + assert torch.allclose(everything, embed_weight) + + +def test_noisy_mean_initialization_with_token_ids_targets_exact_rows(): + """New tokens placed by explicit IDs must hit those rows, even inside the padding zone.""" + torch.manual_seed(0) + vocab_size, embedding_dim = 20, 8 + embed_weight = torch.zeros(vocab_size, embedding_dim) + # existing rows carry a constant so the mean is well-defined and non-zero + embed_weight[:16] = 1.0 + + # num_new_tokens reflects the embedding resize delta (4 padded rows), + # but the real new tokens sit at IDs 16 and 17 (inside what the tail slice would miss/over-cover). + target_ids = [16, 17] + _noisy_mean_initialization(embed_weight, num_new_tokens=4, token_ids=target_ids) + + # targeted rows are initialized around the mean (~1.0) and not left at zero + for tid in target_ids: + assert not torch.allclose(embed_weight[tid], torch.zeros(embedding_dim)) + assert abs(embed_weight[tid].mean().item() - 1.0) < 0.5 + + # untouched padding rows (18, 19) must remain zero + assert torch.allclose(embed_weight[18], torch.zeros(embedding_dim)) + assert torch.allclose(embed_weight[19], torch.zeros(embedding_dim)) + + +def test_noisy_mean_initialization_tail_fallback(): + """Without token_ids, falls back to the last num_new_tokens rows.""" + torch.manual_seed(0) + vocab_size, embedding_dim = 12, 8 + embed_weight = torch.zeros(vocab_size, embedding_dim) + embed_weight[:10] = 1.0 + + _noisy_mean_initialization(embed_weight, num_new_tokens=2, token_ids=None) + + # last two rows initialized, earlier rows untouched + assert not torch.allclose(embed_weight[-1], torch.zeros(embedding_dim)) + assert not torch.allclose(embed_weight[-2], torch.zeros(embedding_dim)) + assert torch.allclose(embed_weight[0], torch.ones(embedding_dim)) + + +def test_description_init_excludes_new_token_ids_from_average(): + """Description tokens that are themselves new (uninitialized) must be excluded. + + Reproduces the padding-zone bug: id 17 is a new token and must not pollute the + semantic average for id 16; only the valid existing token (id 5) should be used. + """ + vocab_size, embedding_dim = 20, 4 + embed_weight = torch.zeros(vocab_size, embedding_dim) + embed_weight[5] = 3.0 # the only valid description token + + # description for "" tokenizes to [5 (existing), 17 (new -> must be skipped)] + tokenizer = _StubTokenizer({"": 16}, desc_ids=[5, 17]) + model = _StubModel(embed_weight) + + _description_based_initialization( + embed_weight, + num_new_tokens=4, + descriptions={"": "ignored, ids come from the stub"}, + tokenizer=tokenizer, + model=model, + new_token_ids=[16, 17], + add_noise=False, + ) + + # row 16 must equal embedding of id 5 only (3.0), not the (5,17) average (1.5) + assert torch.allclose(embed_weight[16], torch.full((embedding_dim,), 3.0)) + + +if __name__ == "__main__": + import pytest + + pytest.main([__file__])