mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-06-17 04:38:53 +08:00
[fix] correctly place new token embeddings when embedding is padded (#10547)
This commit is contained in:
@@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from collections.abc import Iterable
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
@@ -29,7 +30,81 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
|
def get_embedding_vocab_size(model: "PreTrainedModel") -> int:
    r"""Return the number of rows of the model's input embedding weight.

    When DeepSpeed ZeRO-3 is enabled, the embedding weight is wrapped in
    ``deepspeed.zero.GatheredParameters`` before its size is read, so the value
    is taken from the gathered parameter rather than from a shard.

    Args:
        model: The model whose input embedding layer is inspected.

    Returns:
        The size of dimension 0 of the input embedding weight.
    """
    weight = model.get_input_embeddings().weight
    if not is_deepspeed_zero3_enabled():
        return weight.size(0)

    import deepspeed  # type: ignore

    # Read the size only while the weight is gathered.
    with deepspeed.zero.GatheredParameters([weight]):
        return weight.size(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_new_token_ids(
|
||||||
|
new_tokens: Optional[Iterable[str]],
|
||||||
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
embed_size: int,
|
||||||
|
) -> Optional[list[int]]:
|
||||||
|
r"""Resolve the explicit embedding-row IDs of the newly added tokens.
|
||||||
|
|
||||||
|
Relying on ``embed_weight[-num_new_tokens:]`` to locate new tokens is unsafe when
|
||||||
|
the model embedding was already padded beyond the tokenizer vocab (e.g. Qwen2.5-VL
|
||||||
|
has vocab 151665 but embedding 151936). In that case the appended tokens land
|
||||||
|
inside the original padding zone and the tail slice points at the wrong rows.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_tokens: Iterable of the newly added token strings.
|
||||||
|
tokenizer: The tokenizer instance.
|
||||||
|
embed_size: Current embedding size (upper bound for valid token IDs).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sorted list of unique, in-range token IDs, or ``None`` when no tokens are
|
||||||
|
given so that callers can fall back to the tail-slice behaviour.
|
||||||
|
"""
|
||||||
|
if not new_tokens:
|
||||||
|
return None
|
||||||
|
|
||||||
|
unk_token_id = getattr(tokenizer, "unk_token_id", None)
|
||||||
|
token_ids: set[int] = set()
|
||||||
|
for token_str in new_tokens:
|
||||||
|
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||||
|
if token_id is None or token_id == unk_token_id or not (0 <= token_id < embed_size):
|
||||||
|
logger.warning_rank0(f"Token '{token_str}' not found or out of range, skipping during init.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
token_ids.add(token_id)
|
||||||
|
|
||||||
|
return sorted(token_ids) or None
|
||||||
|
|
||||||
|
|
||||||
|
def _existing_embeddings(
|
||||||
|
embed_weight: "torch.Tensor", num_new_tokens: int, new_token_ids: Optional[list[int]]
|
||||||
|
) -> "torch.Tensor":
|
||||||
|
"""Return the rows treated as 'existing' embeddings used as the init baseline.
|
||||||
|
|
||||||
|
Prefers excluding the explicit new-token rows (robust to padding). Falls back to
|
||||||
|
dropping the last ``num_new_tokens`` rows when no explicit IDs are available.
|
||||||
|
"""
|
||||||
|
if new_token_ids:
|
||||||
|
mask = torch.ones(embed_weight.size(0), dtype=torch.bool, device=embed_weight.device)
|
||||||
|
mask[torch.as_tensor(new_token_ids, device=embed_weight.device, dtype=torch.long)] = False
|
||||||
|
return embed_weight[mask]
|
||||||
|
|
||||||
|
if num_new_tokens > 0:
|
||||||
|
return embed_weight[:-num_new_tokens]
|
||||||
|
|
||||||
|
return embed_weight
|
||||||
|
|
||||||
|
|
||||||
|
def _noisy_mean_initialization(
|
||||||
|
embed_weight: "torch.Tensor", num_new_tokens: int, token_ids: Optional[list[int]] = None
|
||||||
|
) -> None:
|
||||||
"""Initialize new token embeddings with mean + Gaussian noise.
|
"""Initialize new token embeddings with mean + Gaussian noise.
|
||||||
|
|
||||||
This is the default initialization method used by LlamaFactory.
|
This is the default initialization method used by LlamaFactory.
|
||||||
@@ -37,12 +112,23 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
|
|||||||
Args:
|
Args:
|
||||||
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
|
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
|
||||||
num_new_tokens: Number of new tokens added at the end of the embedding matrix
|
num_new_tokens: Number of new tokens added at the end of the embedding matrix
|
||||||
|
token_ids: Explicit token IDs to initialize. When provided, these exact rows are
|
||||||
|
written (robust to padding). When ``None``, falls back to the last
|
||||||
|
``num_new_tokens`` rows.
|
||||||
"""
|
"""
|
||||||
embedding_dim = embed_weight.size(1)
|
embedding_dim = embed_weight.size(1)
|
||||||
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
|
avg_weight = _existing_embeddings(embed_weight, num_new_tokens, token_ids).mean(dim=0, keepdim=True)
|
||||||
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
|
||||||
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
if token_ids:
|
||||||
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
noise_weight = torch.empty(
|
||||||
|
len(token_ids), embedding_dim, device=embed_weight.device, dtype=embed_weight.dtype
|
||||||
|
)
|
||||||
|
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||||
|
embed_weight[token_ids] = avg_weight + noise_weight
|
||||||
|
else:
|
||||||
|
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
|
||||||
|
noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
|
||||||
|
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
|
||||||
|
|
||||||
|
|
||||||
def _description_based_initialization(
|
def _description_based_initialization(
|
||||||
@@ -51,6 +137,7 @@ def _description_based_initialization(
|
|||||||
descriptions: dict[str, str],
|
descriptions: dict[str, str],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
|
new_token_ids: Optional[list[int]] = None,
|
||||||
add_noise: bool = False,
|
add_noise: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize new token embeddings based on textual descriptions.
|
"""Initialize new token embeddings based on textual descriptions.
|
||||||
@@ -61,6 +148,9 @@ def _description_based_initialization(
|
|||||||
3. Averages them to initialize the new token's embedding
|
3. Averages them to initialize the new token's embedding
|
||||||
4. Optionally adds Gaussian noise
|
4. Optionally adds Gaussian noise
|
||||||
|
|
||||||
|
New tokens are placed by their resolved token ID rather than by tail slicing,
|
||||||
|
so the initialization is correct even when the embedding matrix was padded.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
|
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
|
||||||
num_new_tokens: Number of new tokens added
|
num_new_tokens: Number of new tokens added
|
||||||
@@ -68,6 +158,8 @@ def _description_based_initialization(
|
|||||||
e.g., {"<think>": "A token representing reasoning process"}
|
e.g., {"<think>": "A token representing reasoning process"}
|
||||||
tokenizer: The tokenizer instance
|
tokenizer: The tokenizer instance
|
||||||
model: The model instance (used to get input embeddings)
|
model: The model instance (used to get input embeddings)
|
||||||
|
new_token_ids: IDs of all newly added tokens. Used to exclude not-yet-initialized
|
||||||
|
rows when averaging description-token embeddings (robust to embedding padding).
|
||||||
add_noise: Whether to add Gaussian noise to the initialization
|
add_noise: Whether to add Gaussian noise to the initialization
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@@ -77,38 +169,55 @@ def _description_based_initialization(
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
embedding_dim = embed_weight.size(1)
|
embedding_dim = embed_weight.size(1)
|
||||||
|
vocab_size = embed_weight.size(0)
|
||||||
|
unk_token_id = getattr(tokenizer, "unk_token_id", None)
|
||||||
|
device = embed_weight.device
|
||||||
|
|
||||||
|
# The set of rows that are NOT yet initialized (the newly added tokens). Description
|
||||||
|
# tokens that fall into this set must be excluded, otherwise we would average garbage.
|
||||||
|
# `num_new_tokens` (the padded resize delta) is NOT a reliable boundary, so rely on
|
||||||
|
# the explicit IDs, falling back to resolving them from the description keys.
|
||||||
|
if new_token_ids is None:
|
||||||
|
new_token_ids = _resolve_new_token_ids(descriptions.keys(), tokenizer, vocab_size)
|
||||||
|
|
||||||
|
new_id_set = set(new_token_ids or [])
|
||||||
|
fallback_embedding = _existing_embeddings(embed_weight, num_new_tokens, new_token_ids).mean(dim=0)
|
||||||
|
|
||||||
|
for token_str, desc in descriptions.items():
|
||||||
|
# Resolve token ID for correct placement (robust to embedding padding)
|
||||||
|
token_id = tokenizer.convert_tokens_to_ids(token_str)
|
||||||
|
if token_id is None or token_id == unk_token_id or not (0 <= token_id < vocab_size):
|
||||||
|
logger.warning_rank0(f"desc_init: token '{token_str}' not found or out of range, skipping.")
|
||||||
|
continue
|
||||||
|
|
||||||
for i, desc in enumerate(descriptions.values()):
|
|
||||||
# Tokenize description text
|
# Tokenize description text
|
||||||
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
|
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
token_ids = tokens["input_ids"][0]
|
token_ids = tokens["input_ids"][0].tolist()
|
||||||
# Move to the same device as embed_weight
|
|
||||||
device = embed_weight.device
|
|
||||||
token_ids = token_ids.to(device)
|
|
||||||
|
|
||||||
# Filter out new tokens (they don't have valid embeddings yet)
|
# Keep only description tokens that already have a meaningful embedding.
|
||||||
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
|
valid_token_ids = [tid for tid in token_ids if tid not in new_id_set and 0 <= tid < vocab_size]
|
||||||
|
|
||||||
if len(valid_token_ids) == 0:
|
if len(valid_token_ids) == 0:
|
||||||
# Fallback: use mean of all existing embeddings
|
# Fallback: use mean of all existing embeddings
|
||||||
logger.warning_rank0(
|
logger.warning_rank0(
|
||||||
f"Description for token {i + 1}/{num_new_tokens} contains no valid tokens. "
|
f"Description for token '{token_str}' contains no valid tokens. "
|
||||||
"Using mean of existing embeddings."
|
"Using mean of existing embeddings."
|
||||||
)
|
)
|
||||||
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
|
base_embedding = fallback_embedding
|
||||||
else:
|
else:
|
||||||
# Get embeddings of description tokens and average them
|
# Get embeddings of description tokens and average them
|
||||||
token_embeds = model.get_input_embeddings()(valid_token_ids)
|
valid_ids_tensor = torch.as_tensor(valid_token_ids, device=device, dtype=torch.long)
|
||||||
|
token_embeds = model.get_input_embeddings()(valid_ids_tensor)
|
||||||
base_embedding = token_embeds.mean(dim=0)
|
base_embedding = token_embeds.mean(dim=0)
|
||||||
|
|
||||||
# Add noise if requested (ensure correct device and dtype)
|
# Add noise if requested (ensure correct device and dtype)
|
||||||
if add_noise:
|
if add_noise:
|
||||||
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
|
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
|
||||||
embed_weight[-num_new_tokens + i] = base_embedding + noise
|
embed_weight[token_id] = base_embedding + noise
|
||||||
else:
|
else:
|
||||||
embed_weight[-num_new_tokens + i] = base_embedding
|
embed_weight[token_id] = base_embedding
|
||||||
|
|
||||||
|
|
||||||
def _initialize_embeddings(
|
def _initialize_embeddings(
|
||||||
@@ -118,6 +227,7 @@ def _initialize_embeddings(
|
|||||||
new_special_tokens_config: Optional[dict],
|
new_special_tokens_config: Optional[dict],
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
|
new_token_ids: Optional[list[int]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Single source of truth for embedding initialization.
|
"""Single source of truth for embedding initialization.
|
||||||
|
|
||||||
@@ -130,16 +240,18 @@ def _initialize_embeddings(
|
|||||||
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
|
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
|
||||||
tokenizer: The tokenizer instance
|
tokenizer: The tokenizer instance
|
||||||
model: The model instance
|
model: The model instance
|
||||||
|
new_token_ids: Explicit IDs of the newly added tokens (robust to embedding padding).
|
||||||
|
When ``None``, the init helpers fall back to the last ``num_new_tokens`` rows.
|
||||||
"""
|
"""
|
||||||
if init_method == "desc_init" and new_special_tokens_config:
|
if init_method == "desc_init" and new_special_tokens_config:
|
||||||
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
|
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
|
||||||
_description_based_initialization(
|
_description_based_initialization(
|
||||||
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
|
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=False
|
||||||
)
|
)
|
||||||
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
|
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
|
||||||
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
|
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
|
||||||
_description_based_initialization(
|
_description_based_initialization(
|
||||||
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
|
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, new_token_ids, add_noise=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if init_method != "noise_init":
|
if init_method != "noise_init":
|
||||||
@@ -147,20 +259,28 @@ def _initialize_embeddings(
|
|||||||
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
|
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
|
||||||
)
|
)
|
||||||
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
|
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
|
||||||
_noisy_mean_initialization(embed_weight, num_new_tokens)
|
_noisy_mean_initialization(embed_weight, num_new_tokens, token_ids=new_token_ids)
|
||||||
|
|
||||||
|
|
||||||
def resize_embedding_layer(
|
def resize_embedding_layer(
|
||||||
model: "PreTrainedModel",
|
model: "PreTrainedModel",
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
|
new_tokens: Optional[Iterable[str]] = None,
|
||||||
new_special_tokens_config: Optional[dict] = None,
|
new_special_tokens_config: Optional[dict] = None,
|
||||||
init_special_tokens: str = "noise_init",
|
init_special_tokens: str = "noise_init",
|
||||||
) -> None:
|
) -> None:
|
||||||
r"""Resize token embeddings and initialize new tokens.
|
r"""Resize token embeddings (when needed) and initialize the newly added tokens.
|
||||||
|
|
||||||
|
Resizing and initialization are decoupled: even when the tokenizer vocab fits inside
|
||||||
|
the model's existing (padded) embedding matrix and no resize is triggered, the newly
|
||||||
|
added tokens still occupy uninitialized rows and must be initialized. We therefore
|
||||||
|
resolve the explicit row IDs of ``new_tokens`` and always initialize those rows.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: The model to resize
|
model: The model to resize
|
||||||
tokenizer: The tokenizer (used to get target vocab size)
|
tokenizer: The tokenizer (used to get target vocab size)
|
||||||
|
new_tokens: Iterable of the newly added token strings. Used to locate the exact
|
||||||
|
embedding rows to initialize, which is robust to pre-existing embedding padding.
|
||||||
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
|
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
|
||||||
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
|
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
|
||||||
"""
|
"""
|
||||||
@@ -175,44 +295,70 @@ def resize_embedding_layer(
|
|||||||
else:
|
else:
|
||||||
context_maybe_zero3 = nullcontext()
|
context_maybe_zero3 = nullcontext()
|
||||||
|
|
||||||
with context_maybe_zero3:
|
current_embedding_size = get_embedding_vocab_size(model)
|
||||||
current_embedding_size = model.get_input_embeddings().weight.size(0)
|
needs_resize = len(tokenizer) > current_embedding_size
|
||||||
|
|
||||||
if len(tokenizer) > current_embedding_size:
|
if needs_resize:
|
||||||
if getattr(model, "quantization_method", None):
|
if getattr(model, "quantization_method", None):
|
||||||
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
||||||
|
|
||||||
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
|
||||||
raise ValueError("Current model does not support resizing embedding layers.")
|
raise ValueError("Current model does not support resizing embedding layers.")
|
||||||
|
|
||||||
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
|
# mean_resizing=False preserves the original embedding distribution exactly.
|
||||||
with context_maybe_zero3:
|
# HuggingFace's default mean_resizing=True re-samples new rows from the mean/covariance
|
||||||
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
# of existing embeddings, which conflicts with our explicit initialization below.
|
||||||
num_new_tokens = new_embedding_size - current_embedding_size
|
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64, mean_resizing=False)
|
||||||
|
|
||||||
|
with context_maybe_zero3:
|
||||||
|
new_embedding_size = model.get_input_embeddings().weight.size(0)
|
||||||
|
num_new_tokens = new_embedding_size - current_embedding_size
|
||||||
|
|
||||||
|
# Resolve the exact rows of the new tokens. This works whether or not a resize was
|
||||||
|
# triggered (e.g. tokens added into a model's pre-existing padding zone).
|
||||||
|
new_token_ids = _resolve_new_token_ids(new_tokens, tokenizer, new_embedding_size)
|
||||||
|
|
||||||
|
if num_new_tokens <= 0 and not new_token_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
if needs_resize:
|
||||||
logger.info_rank0(
|
logger.info_rank0(
|
||||||
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
|
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logger.info_rank0(
|
||||||
|
f"No resize needed (vocab fits in padded embedding {new_embedding_size}); "
|
||||||
|
f"initializing {len(new_token_ids or [])} new token(s) in place."
|
||||||
|
)
|
||||||
|
|
||||||
# Initialize input embeddings
|
# Initialize input embeddings
|
||||||
|
_initialize_embeddings(
|
||||||
|
model.get_input_embeddings().weight.data,
|
||||||
|
num_new_tokens,
|
||||||
|
init_special_tokens,
|
||||||
|
new_special_tokens_config,
|
||||||
|
tokenizer,
|
||||||
|
model,
|
||||||
|
new_token_ids=new_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize output embeddings if not tied
|
||||||
|
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
||||||
_initialize_embeddings(
|
_initialize_embeddings(
|
||||||
model.get_input_embeddings().weight.data,
|
model.get_output_embeddings().weight.data,
|
||||||
num_new_tokens,
|
num_new_tokens,
|
||||||
init_special_tokens,
|
init_special_tokens,
|
||||||
new_special_tokens_config,
|
new_special_tokens_config,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
model,
|
model,
|
||||||
|
new_token_ids=new_token_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize output embeddings if not tied
|
if needs_resize:
|
||||||
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
|
|
||||||
_initialize_embeddings(
|
|
||||||
model.get_output_embeddings().weight.data,
|
|
||||||
num_new_tokens,
|
|
||||||
init_special_tokens,
|
|
||||||
new_special_tokens_config,
|
|
||||||
tokenizer,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
|
|
||||||
model.config.vocab_size = new_embedding_size
|
model.config.vocab_size = new_embedding_size
|
||||||
|
# Also update the nested text_config for VL models (e.g., Qwen2.5-VL, LLaVA),
|
||||||
|
# otherwise config.vocab_size and config.text_config.vocab_size become inconsistent.
|
||||||
|
if hasattr(model.config, "text_config") and hasattr(model.config.text_config, "vocab_size"):
|
||||||
|
model.config.text_config.vocab_size = new_embedding_size
|
||||||
|
|
||||||
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
|
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
|
||||||
|
|||||||
@@ -457,9 +457,14 @@ def patch_model(
|
|||||||
prepare_valuehead_model(model)
|
prepare_valuehead_model(model)
|
||||||
|
|
||||||
if model_args.resize_vocab:
|
if model_args.resize_vocab:
|
||||||
|
# Pass the explicit list of newly added tokens so their exact embedding rows can be
|
||||||
|
# located and initialized, even when they land in a model's pre-existing padding zone.
|
||||||
|
new_tokens = (model_args.add_tokens or []) + (model_args.add_special_tokens or [])
|
||||||
|
|
||||||
resize_embedding_layer(
|
resize_embedding_layer(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
|
new_tokens=new_tokens or None,
|
||||||
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
|
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
|
||||||
init_special_tokens=model_args.init_special_tokens,
|
init_special_tokens=model_args.init_special_tokens,
|
||||||
)
|
)
|
||||||
|
|||||||
149
tests/model/model_utils/test_embedding.py
Normal file
149
tests/model/model_utils/test_embedding.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from llamafactory.model.model_utils.embedding import (
|
||||||
|
_description_based_initialization,
|
||||||
|
_existing_embeddings,
|
||||||
|
_noisy_mean_initialization,
|
||||||
|
_resolve_new_token_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _StubTokenizer:
|
||||||
|
"""Minimal tokenizer stub mapping token strings to fixed IDs."""
|
||||||
|
|
||||||
|
unk_token_id = 0
|
||||||
|
|
||||||
|
def __init__(self, mapping: dict[str, int], desc_ids: list[int] | None = None):
|
||||||
|
self._mapping = mapping
|
||||||
|
self._desc_ids = desc_ids or []
|
||||||
|
|
||||||
|
def convert_tokens_to_ids(self, token: str) -> int:
|
||||||
|
return self._mapping.get(token, self.unk_token_id)
|
||||||
|
|
||||||
|
def __call__(self, desc, return_tensors=None, add_special_tokens=False):
|
||||||
|
return {"input_ids": torch.tensor([self._desc_ids], dtype=torch.long)}
|
||||||
|
|
||||||
|
|
||||||
|
class _StubModel:
|
||||||
|
"""Wraps an embedding matrix so ``get_input_embeddings()`` is a usable lookup."""
|
||||||
|
|
||||||
|
def __init__(self, embed_weight: "torch.Tensor"):
|
||||||
|
self._emb = torch.nn.Embedding.from_pretrained(embed_weight.clone(), freeze=True)
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self._emb
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_new_token_ids_returns_none_without_config():
    """``None`` and an empty list of tokens both resolve to ``None``."""
    tokenizer = _StubTokenizer({})
    for empty_input in (None, []):
        assert _resolve_new_token_ids(empty_input, tokenizer, embed_size=100) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_new_token_ids_filters_invalid_and_dedups():
    """Unknown and out-of-range tokens are dropped; the result is sorted and unique."""
    # "<a>" and "<b>" are valid, "<unk_like>" maps to unk_token_id, "<oob>" exceeds embed_size.
    mapping = {"<a>": 10, "<unk_like>": 0, "<oob>": 999, "<b>": 5}
    tokenizer = _StubTokenizer(mapping)

    # Unsorted input containing a duplicate -> sorted unique in-range IDs.
    resolved = _resolve_new_token_ids(
        ["<a>", "<a>", "<unk_like>", "<oob>", "<b>"], tokenizer, embed_size=100
    )
    assert resolved == [5, 10]

    # A dict is iterated by key (config compatibility).
    resolved_from_dict = _resolve_new_token_ids({"<a>": "desc"}, tokenizer, embed_size=100)
    assert resolved_from_dict == [10]
|
||||||
|
|
||||||
|
|
||||||
|
def test_existing_embeddings_excludes_new_token_ids():
    """Explicit IDs win over the tail slice; without IDs the tail (or nothing) is dropped."""
    rows, dim = 10, 2
    embed_weight = torch.arange(rows * dim, dtype=torch.float32).reshape(rows, dim)

    # Explicit IDs take precedence over num_new_tokens and remove exactly those rows.
    with_ids = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=[2, 5])
    assert with_ids.size(0) == 8

    # Without explicit IDs the last num_new_tokens rows are dropped.
    without_ids = _existing_embeddings(embed_weight, num_new_tokens=3, new_token_ids=None)
    assert torch.allclose(without_ids, embed_weight[:-3])

    # Neither a resize delta nor IDs: the whole matrix is returned.
    untouched = _existing_embeddings(embed_weight, num_new_tokens=0, new_token_ids=None)
    assert torch.allclose(untouched, embed_weight)
|
||||||
|
|
||||||
|
|
||||||
|
def test_noisy_mean_initialization_with_token_ids_targets_exact_rows():
    """Explicit token IDs must be the rows that get written, even inside the padding zone."""
    torch.manual_seed(0)
    vocab_size, embedding_dim = 20, 8
    zero_row = torch.zeros(embedding_dim)

    # Rows 0..15 hold a constant so the baseline mean is non-zero; rows 16..19 start at zero.
    embed_weight = torch.zeros(vocab_size, embedding_dim)
    embed_weight[:16] = 1.0

    # num_new_tokens is the resize delta (4 rows), while the real new tokens are only
    # IDs 16 and 17, so a tail slice of 4 rows would also overwrite rows 18 and 19.
    target_ids = [16, 17]
    _noisy_mean_initialization(embed_weight, num_new_tokens=4, token_ids=target_ids)

    # Targeted rows are no longer zero and their average stays within 0.5 of 1.0.
    for tid in target_ids:
        row = embed_weight[tid]
        assert not torch.allclose(row, zero_row)
        assert abs(row.mean().item() - 1.0) < 0.5

    # Padding rows that were not targeted must remain zero.
    for padding_id in (18, 19):
        assert torch.allclose(embed_weight[padding_id], zero_row)
|
||||||
|
|
||||||
|
|
||||||
|
def test_noisy_mean_initialization_tail_fallback():
    """With ``token_ids=None`` the last ``num_new_tokens`` rows are the ones initialized."""
    torch.manual_seed(0)
    vocab_size, embedding_dim = 12, 8
    zero_row = torch.zeros(embedding_dim)

    embed_weight = torch.zeros(vocab_size, embedding_dim)
    embed_weight[:10] = 1.0

    _noisy_mean_initialization(embed_weight, num_new_tokens=2, token_ids=None)

    # The two trailing rows were written ...
    for tail_index in (-1, -2):
        assert not torch.allclose(embed_weight[tail_index], zero_row)

    # ... while an earlier row keeps its original value.
    assert torch.allclose(embed_weight[0], torch.ones(embedding_dim))
|
||||||
|
|
||||||
|
|
||||||
|
def test_description_init_excludes_new_token_ids_from_average():
    """Description tokens that are themselves new (uninitialized) are left out of the average.

    Covers the padding-zone case: id 17 is a new token and must not contribute to the
    semantic average for id 16; only the existing token (id 5) should be used.
    """
    vocab_size, embedding_dim = 20, 4
    embed_weight = torch.zeros(vocab_size, embedding_dim)
    embed_weight[5] = 3.0  # the only description token with a meaningful embedding

    # The stub tokenizes every description to [5 (existing), 17 (new -> must be skipped)].
    tokenizer = _StubTokenizer({"<x>": 16}, desc_ids=[5, 17])
    model = _StubModel(embed_weight)

    _description_based_initialization(
        embed_weight,
        num_new_tokens=4,
        descriptions={"<x>": "ignored, ids come from the stub"},
        tokenizer=tokenizer,
        model=model,
        new_token_ids=[16, 17],
        add_noise=False,
    )

    # Row 16 must equal the embedding of id 5 alone (3.0), not the (5, 17) average (1.5).
    expected_row = torch.full((embedding_dim,), 3.0)
    assert torch.allclose(embed_weight[16], expected_row)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import pytest

    # pytest.main returns the run's exit code; propagate it so that executing this file
    # directly reports failure to the shell instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__]))
|
||||||
Reference in New Issue
Block a user