[model] adds semantic initialization support for special tokens (#9267)

Co-authored-by: ximingxing <ximingxing@tencent.com>
Ximing Xing 2025-10-14 17:00:48 +08:00 committed by GitHub
parent 3dbca4b533
commit c867e28093
4 changed files with 264 additions and 7 deletions

View File

@@ -0,0 +1,25 @@
# SVG Container Tags
"<|START_OF_SVG|>": "Marks the beginning of an SVG document"
"<|END_OF_SVG|>": "Marks the end of an SVG document"
# SVG Group Tags
"<|start_of_g|>": "Begins a group element in SVG for organizing related shapes"
"<|end_of_g|>": "Ends a group element"
# SVG Shape Tags
"<|start_of_rect|>": "Begins a rectangle shape with width and height attributes"
"<|end_of_rect|>": "Ends a rectangle shape definition"
"<|start_of_circle|>": "Begins a circular shape with radius attribute"
"<|end_of_circle|>": "Ends a circular shape definition"
"<|start_of_path|>": "Begins a path element for drawing custom vector graphics"
"<|end_of_path|>": "Ends a path element definition"
"<|start_of_ellipse|>": "Begins an ellipse shape with x and y radii"
"<|end_of_ellipse|>": "Ends an ellipse shape definition"
# SVG Text Tags
"<|start_of_text|>": "Begins a text element for rendering text content"
"<|end_of_text|>": "Ends a text element"
# SVG Style Tags
"<|start_of_style|>": "Begins a style definition block for CSS styling"
"<|end_of_style|>": "Ends a style definition block"

View File

@@ -22,9 +22,12 @@ from typing import Any, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from omegaconf import OmegaConf
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
from ..extras.logging import get_logger
logger = get_logger(__name__)
@dataclass
class BaseModelArguments:
@@ -75,6 +78,28 @@ class BaseModelArguments:
default=None,
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
)
new_special_tokens_config: Optional[str] = field(
default=None,
metadata={
"help": (
"Path to YAML config with special token descriptions for semantic initialization. "
"If set, this takes precedence over add_special_tokens. "
"YAML format: {'<token>': 'description text', ...}"
)
},
)
init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
default="noise_init",
metadata={
"help": (
"Initialization method for new special tokens: "
"'noise_init' (default, random noise around mean), "
"'desc_init' (semantic initialization from descriptions), "
"'desc_init_w_noise' (semantic + random noise). "
"Note: 'desc_init' methods require new_special_tokens_config."
)
},
)
model_revision: str = field(
default="main",
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -185,8 +210,63 @@ class BaseModelArguments:
if self.add_tokens is not None: # support multiple tokens
self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
if self.add_special_tokens is not None: # support multiple special tokens
# Process special tokens with priority: new_special_tokens_config > add_special_tokens
if self.new_special_tokens_config is not None:
# Priority 1: Load from YAML config (extracts both tokens and descriptions)
try:
cfg = OmegaConf.load(self.new_special_tokens_config)
token_descriptions = OmegaConf.to_container(cfg)
if not isinstance(token_descriptions, dict):
raise ValueError(
f"YAML config must be a dictionary mapping tokens to descriptions. "
f"Got: {type(token_descriptions)}"
)
# Extract token list from config keys
extracted_tokens = list(token_descriptions.keys())
# Warn if both are set
if self.add_special_tokens is not None:
logger.warning_rank0(
"Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
f"Using tokens from config: {extracted_tokens}"
)
# Override add_special_tokens with extracted tokens (as list)
self.add_special_tokens = extracted_tokens
# Store descriptions internally for later use (internal attribute)
self._special_token_descriptions = token_descriptions
logger.info_rank0(
f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
f"{self.new_special_tokens_config}"
)
except Exception as e:
logger.error_rank0(
f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
)
raise
elif self.add_special_tokens is not None:
# Priority 2: Use simple comma-separated string (no descriptions)
self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
self._special_token_descriptions = None
else:
# No special tokens to add
self._special_token_descriptions = None
# Validate init method
if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
if self._special_token_descriptions is None:
logger.warning_rank0(
f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
"Falling back to 'noise_init'"
)
self.init_special_tokens = "noise_init"
@dataclass
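For illustration, how the two new fields interact when both are supplied. This is a sketch, not a snippet from the repository: the import path, the checkpoint name, and the assumption that the remaining BaseModelArguments fields keep their defaults are all placeholders.

from llamafactory.hparams.model_args import BaseModelArguments  # module path assumed from the repo layout

model_args = BaseModelArguments(
    model_name_or_path="Qwen/Qwen2.5-0.5B",                # placeholder checkpoint
    new_special_tokens_config="svg_special_tokens.yaml",   # takes precedence
    add_special_tokens="<|START_OF_SVG|>,<|END_OF_SVG|>",  # ignored with a warning because the config is set
    init_special_tokens="desc_init",
)
# After __post_init__:
#   model_args.add_special_tokens          -> list of the YAML keys
#   model_args._special_token_descriptions -> token -> description dict
# Without new_special_tokens_config, init_special_tokens="desc_init" falls back to "noise_init".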

View File

@@ -14,7 +14,7 @@
import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)
def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
"""Initialize new token embeddings with mean + Gaussian noise.
This is the default initialization method used by LlamaFactory.
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added at the end of the embedding matrix
"""
embedding_dim = embed_weight.size(1)
avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
embed_weight[-num_new_tokens:] = avg_weight + noise_weight
def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
r"""Resize token embeddings."""
def _description_based_initialization(
embed_weight: "torch.Tensor",
num_new_tokens: int,
descriptions: dict[str, str],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
add_noise: bool = False,
) -> None:
"""Initialize new token embeddings based on textual descriptions.
For each new token, this function:
1. Tokenizes its description text
2. Gets embeddings of the description tokens
3. Averages them to initialize the new token's embedding
4. Optionally adds Gaussian noise
Args:
embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
num_new_tokens: Number of new tokens added
descriptions: Dict mapping token string to its description text
e.g., {"<think>": "A token representing reasoning process"}
tokenizer: The tokenizer instance
model: The model instance (used to get input embeddings)
add_noise: Whether to add Gaussian noise to the initialization
Example:
descriptions = {
"<|START_OF_SVG|>": "Marks the beginning of an SVG document",
"<|END_OF_SVG|>": "Marks the end of an SVG document"
}
"""
embedding_dim = embed_weight.size(1)
for i, desc in enumerate(descriptions.values()):
# Tokenize description text
tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
token_ids = tokens["input_ids"][0]
# Move to the same device as embed_weight
device = embed_weight.device
token_ids = token_ids.to(device)
# Filter out new tokens (they don't have valid embeddings yet)
valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
if len(valid_token_ids) == 0:
# Fallback: use mean of all existing embeddings
logger.warning_rank0(
f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
"Using mean of existing embeddings."
)
base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
else:
# Get embeddings of description tokens and average them
token_embeds = model.get_input_embeddings()(valid_token_ids)
base_embedding = token_embeds.mean(dim=0)
# Add noise if requested (ensure correct device and dtype)
if add_noise:
noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
embed_weight[-num_new_tokens + i] = base_embedding + noise
else:
embed_weight[-num_new_tokens + i] = base_embedding
def _initialize_embeddings(
embed_weight: "torch.Tensor",
num_new_tokens: int,
init_method: str,
new_special_tokens_config: Optional[dict],
tokenizer: "PreTrainedTokenizer",
model: "PreTrainedModel",
) -> None:
"""Single source of truth for embedding initialization.
This function selects the appropriate initialization method and applies it.
Args:
embed_weight: The embedding weight matrix to initialize
num_new_tokens: Number of new tokens added
init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
tokenizer: The tokenizer instance
model: The model instance
"""
if init_method == "desc_init" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
)
elif init_method == "desc_init_w_noise" and new_special_tokens_config:
logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
_description_based_initialization(
embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
)
else:
if init_method != "noise_init":
logger.warning_rank0(
f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
)
logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
_noisy_mean_initialization(embed_weight, num_new_tokens)
def resize_embedding_layer(
model: "PreTrainedModel",
tokenizer: "PreTrainedTokenizer",
new_special_tokens_config: Optional[dict] = None,
init_special_tokens: str = "noise_init",
) -> None:
r"""Resize token embeddings and initialize new tokens.
Args:
model: The model to resize
tokenizer: The tokenizer (used to get target vocab size)
new_special_tokens_config: Optional dict with token descriptions for semantic initialization
init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
"""
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
@@ -64,8 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
with context_maybe_zero3:
new_embedding_size = model.get_input_embeddings().weight.size(0)
num_new_tokens = new_embedding_size - current_embedding_size
_noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
_noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
logger.info_rank0(
f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
)
# Initialize input embeddings
_initialize_embeddings(
model.get_input_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
# Initialize output embeddings if not tied
if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
_initialize_embeddings(
model.get_output_embeddings().weight.data,
num_new_tokens,
init_special_tokens,
new_special_tokens_config,
tokenizer,
model,
)
model.config.vocab_size = new_embedding_size
logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

View File

@@ -175,7 +175,12 @@ def patch_model(
prepare_valuehead_model(model)
if model_args.resize_vocab:
resize_embedding_layer(model, tokenizer)
resize_embedding_layer(
model,
tokenizer,
new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
init_special_tokens=model_args.init_special_tokens,
)
if is_trainable:
if getattr(model.config, "model_type", None) == "gemma3n":
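End to end, the same path can be exercised outside patch_model. A sketch, assuming resize_embedding_layer is imported from the embedding utilities shown above and using a placeholder checkpoint:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")   # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

descriptions = {
    "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
    "<|END_OF_SVG|>": "Marks the end of an SVG document",
}
tokenizer.add_special_tokens({"additional_special_tokens": list(descriptions.keys())})

resize_embedding_layer(  # the function defined in the embedding utilities above
    model,
    tokenizer,
    new_special_tokens_config=descriptions,
    init_special_tokens="desc_init",
)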