[fix] correctly place new token embeddings when embedding is padded (#10547)

This commit is contained in:
Ximing Xing
2026-06-05 10:47:51 +08:00
committed by GitHub
parent 8a4f6a3da5
commit 0b7aaf8f6a
3 changed files with 341 additions and 41 deletions

View File

@@ -0,0 +1,149 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from llamafactory.model.model_utils.embedding import (
_description_based_initialization,
_existing_embeddings,
_noisy_mean_initialization,
_resolve_new_token_ids,
)
class _StubTokenizer:
"""Minimal tokenizer stub mapping token strings to fixed IDs."""
unk_token_id = 0
def __init__(self, mapping: dict[str, int], desc_ids: list[int] | None = None):
self._mapping = mapping
self._desc_ids = desc_ids or []
def convert_tokens_to_ids(self, token: str) -> int:
return self._mapping.get(token, self.unk_token_id)
def __call__(self, desc, return_tensors=None, add_special_tokens=False):
return {"input_ids": torch.tensor([self._desc_ids], dtype=torch.long)}
class _StubModel:
"""Wraps an embedding matrix so ``get_input_embeddings()`` is a usable lookup."""
def __init__(self, embed_weight: "torch.Tensor"):
self._emb = torch.nn.Embedding.from_pretrained(embed_weight.clone(), freeze=True)
def get_input_embeddings(self):
return self._emb
def test_resolve_new_token_ids_returns_none_without_config():
tokenizer = _StubTokenizer({})
assert _resolve_new_token_ids(None, tokenizer, embed_size=100) is None
assert _resolve_new_token_ids([], tokenizer, embed_size=100) is None
def test_resolve_new_token_ids_filters_invalid_and_dedups():
# "<a>" valid, "<unk_like>" maps to unk_token_id (skipped), "<oob>" out of range (skipped)
tokenizer = _StubTokenizer({"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5})
# duplicates and unsorted input -> sorted unique in-range IDs
tokens = ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"]
assert _resolve_new_token_ids(tokens, tokenizer, embed_size=100) == [5, 10]
# passing a dict iterates its keys (config compatibility)
assert _resolve_new_token_ids({"<a>": "desc"}, tokenizer, embed_size=100) == [10]
def test_existing_embeddings_excludes_new_token_ids():
embed_weight = torch.arange(10 * 2, dtype=torch.float32).reshape(10, 2)
# explicit ids take precedence and drop exactly those rows
existing = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=[2, 5])
assert existing.size(0) == 8
# tail fallback when no explicit ids
tail = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=None)
assert torch.allclose(tail, embed_weight[:-3])
# no resize and no ids -> use everything
everything = _existing_embeddings(embed_weight, num_new_tokens=0, new_token_ids=None)
assert torch.allclose(everything, embed_weight)
def test_noisy_mean_initialization_with_token_ids_targets_exact_rows():
"""New tokens placed by explicit IDs must hit those rows, even inside the padding zone."""
torch.manual_seed(0)
vocab_size, embedding_dim = 20, 8
embed_weight = torch.zeros(vocab_size, embedding_dim)
# existing rows carry a constant so the mean is well-defined and non-zero
embed_weight[:16] = 1.0
# num_new_tokens reflects the embedding resize delta (4 padded rows),
# but the real new tokens sit at IDs 16 and 17 (inside what the tail slice would miss/over-cover).
target_ids = [16, 17]
_noisy_mean_initialization(embed_weight, num_new_tokens=4, token_ids=target_ids)
# targeted rows are initialized around the mean (~1.0) and not left at zero
for tid in target_ids:
assert not torch.allclose(embed_weight[tid], torch.zeros(embedding_dim))
assert abs(embed_weight[tid].mean().item() - 1.0) < 0.5
# untouched padding rows (18, 19) must remain zero
assert torch.allclose(embed_weight[18], torch.zeros(embedding_dim))
assert torch.allclose(embed_weight[19], torch.zeros(embedding_dim))
def test_noisy_mean_initialization_tail_fallback():
"""Without token_ids, falls back to the last num_new_tokens rows."""
torch.manual_seed(0)
vocab_size, embedding_dim = 12, 8
embed_weight = torch.zeros(vocab_size, embedding_dim)
embed_weight[:10] = 1.0
_noisy_mean_initialization(embed_weight, num_new_tokens=2, token_ids=None)
# last two rows initialized, earlier rows untouched
assert not torch.allclose(embed_weight[-1], torch.zeros(embedding_dim))
assert not torch.allclose(embed_weight[-2], torch.zeros(embedding_dim))
assert torch.allclose(embed_weight[0], torch.ones(embedding_dim))
def test_description_init_excludes_new_token_ids_from_average():
"""Description tokens that are themselves new (uninitialized) must be excluded.
Reproduces the padding-zone bug: id 17 is a new token and must not pollute the
semantic average for id 16; only the valid existing token (id 5) should be used.
"""
vocab_size, embedding_dim = 20, 4
embed_weight = torch.zeros(vocab_size, embedding_dim)
embed_weight[5] = 3.0 # the only valid description token
# description for "<x>" tokenizes to [5 (existing), 17 (new -> must be skipped)]
tokenizer = _StubTokenizer({"<x>": 16}, desc_ids=[5, 17])
model = _StubModel(embed_weight)
_description_based_initialization(
embed_weight,
num_new_tokens=4,
descriptions={"<x>": "ignored, ids come from the stub"},
tokenizer=tokenizer,
model=model,
new_token_ids=[16, 17],
add_noise=False,
)
# row 16 must equal embedding of id 5 only (3.0), not the (5,17) average (1.5)
assert torch.allclose(embed_weight[16], torch.full((embedding_dim,), 3.0))
if __name__ == "__main__":
import pytest
pytest.main([__file__])