From c867e280937b71904ce20e5b99d562bf9faf042b Mon Sep 17 00:00:00 2001
From: Ximing Xing
Date: Tue, 14 Oct 2025 17:00:48 +0800
Subject: [PATCH] [model] adds semantic initialization support for special tokens (#9267)

Co-authored-by: ximingxing
---
 examples/extras/multi_tokens/tokens_cfg.yaml  |  25 +++
 src/llamafactory/hparams/model_args.py        |  82 ++++++++-
 .../model/model_utils/embedding.py            | 157 +++++++++++++++++-
 src/llamafactory/model/patcher.py             |   7 +-
 4 files changed, 264 insertions(+), 7 deletions(-)
 create mode 100644 examples/extras/multi_tokens/tokens_cfg.yaml

diff --git a/examples/extras/multi_tokens/tokens_cfg.yaml b/examples/extras/multi_tokens/tokens_cfg.yaml
new file mode 100644
index 00000000..2b2d0c28
--- /dev/null
+++ b/examples/extras/multi_tokens/tokens_cfg.yaml
@@ -0,0 +1,25 @@
+# SVG Container Tags
+"<|START_OF_SVG|>": "Marks the beginning of an SVG document"
+"<|END_OF_SVG|>": "Marks the end of an SVG document"
+
+# SVG Group Tags
+"<|start_of_g|>": "Begins a group element in SVG for organizing related shapes"
+"<|end_of_g|>": "Ends a group element"
+
+# SVG Shape Tags
+"<|start_of_rect|>": "Begins a rectangle shape with width and height attributes"
+"<|end_of_rect|>": "Ends a rectangle shape definition"
+"<|start_of_circle|>": "Begins a circular shape with radius attribute"
+"<|end_of_circle|>": "Ends a circular shape definition"
+"<|start_of_path|>": "Begins a path element for drawing custom vector graphics"
+"<|end_of_path|>": "Ends a path element definition"
+"<|start_of_ellipse|>": "Begins an ellipse shape with x and y radii"
+"<|end_of_ellipse|>": "Ends an ellipse shape definition"
+
+# SVG Text Tags
+"<|start_of_text|>": "Begins a text element for rendering text content"
+"<|end_of_text|>": "Ends a text element"
+
+# SVG Style Tags
+"<|start_of_style|>": "Begins a style definition block for CSS styling"
+"<|end_of_style|>": "Ends a style definition block"
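A token-description file like the one above is consumed through the two new options added in `model_args.py` below, and it only takes effect when the embedding layer is actually resized. A minimal training-config sketch, with illustrative values: only `new_special_tokens_config` and `init_special_tokens` come from this patch; the other keys are pre-existing LlamaFactory arguments.

```yaml
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct  # hypothetical base model
stage: sft
do_train: true
finetuning_type: lora
resize_vocab: true  # required so patch_model() calls resize_embedding_layer()
new_special_tokens_config: examples/extras/multi_tokens/tokens_cfg.yaml
init_special_tokens: desc_init  # or noise_init / desc_init_w_noise
```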
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 0576b729..7762194a 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -22,9 +22,12 @@ from typing import Any, Literal, Optional, Union
 import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
+from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
+from ..extras.logging import get_logger
 
+logger = get_logger(__name__)
 
 @dataclass
 class BaseModelArguments:
@@ -75,6 +78,28 @@ class BaseModelArguments:
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
+    new_special_tokens_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to YAML config with special token descriptions for semantic initialization. "
+                "If set, this takes precedence over add_special_tokens. "
+                "YAML format: {'<token>': 'description text', ...}"
+            )
+        },
+    )
+    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
+        default="noise_init",
+        metadata={
+            "help": (
+                "Initialization method for new special tokens: "
+                "'noise_init' (default, random noise around mean), "
+                "'desc_init' (semantic initialization from descriptions), "
+                "'desc_init_w_noise' (semantic + random noise). "
+                "Note: 'desc_init' methods require new_special_tokens_config."
+            )
+        },
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
     )
@@ -185,8 +210,63 @@
         if self.add_tokens is not None:  # support multiple tokens
             self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
 
-        if self.add_special_tokens is not None:  # support multiple special tokens
+        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
+        if self.new_special_tokens_config is not None:
+            # Priority 1: Load from YAML config (extracts both tokens and descriptions)
+            try:
+                cfg = OmegaConf.load(self.new_special_tokens_config)
+                token_descriptions = OmegaConf.to_container(cfg)
+
+                if not isinstance(token_descriptions, dict):
+                    raise ValueError(
+                        f"YAML config must be a dictionary mapping tokens to descriptions. "
+                        f"Got: {type(token_descriptions)}"
+                    )
+
+                # Extract token list from config keys
+                extracted_tokens = list(token_descriptions.keys())
+
+                # Warn if both are set
+                if self.add_special_tokens is not None:
+                    logger.warning_rank0(
+                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
+                        f"Using tokens from config: {extracted_tokens}"
+                    )
+
+                # Override add_special_tokens with extracted tokens (as list)
+                self.add_special_tokens = extracted_tokens
+
+                # Store descriptions internally for later use (internal attribute)
+                self._special_token_descriptions = token_descriptions
+
+                logger.info_rank0(
+                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
+                    f"{self.new_special_tokens_config}"
+                )
+
+            except Exception as e:
+                logger.error_rank0(
+                    f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
+                )
+                raise
+
+        elif self.add_special_tokens is not None:
+            # Priority 2: Use simple comma-separated string (no descriptions)
             self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
+            self._special_token_descriptions = None
+
+        else:
+            # No special tokens to add
+            self._special_token_descriptions = None
+
+        # Validate init method
+        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
+            if self._special_token_descriptions is None:
+                logger.warning_rank0(
+                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
+                    "Falling back to 'noise_init'"
+                )
+                self.init_special_tokens = "noise_init"
 
 
 @dataclass
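The `__post_init__` branch above reduces to: load the YAML with OmegaConf, use its keys as the list of special tokens, and keep the values as descriptions for later initialization. A standalone sketch of that parsing, with the file path taken from the example config above:

```python
from omegaconf import OmegaConf

# Same loading logic as BaseModelArguments.__post_init__ above.
cfg = OmegaConf.load("examples/extras/multi_tokens/tokens_cfg.yaml")
token_descriptions = OmegaConf.to_container(cfg)  # plain dict: {token: description}

add_special_tokens = list(token_descriptions.keys())  # what gets passed to the tokenizer
print(len(add_special_tokens), add_special_tokens[0])  # -> 16 <|START_OF_SVG|>
```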
diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py
index 5604d854..ef53a789 100644
--- a/src/llamafactory/model/model_utils/embedding.py
+++ b/src/llamafactory/model/model_utils/embedding.py
@@ -14,7 +14,7 @@
 import math
 from contextlib import nullcontext
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
 
@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)
 
 
 def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
+    """Initialize new token embeddings with mean + Gaussian noise.
+
+    This is the default initialization method used by LlamaFactory.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added at the end of the embedding matrix
+    """
     embedding_dim = embed_weight.size(1)
     avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
     noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
@@ -37,8 +45,125 @@
     noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
     embed_weight[-num_new_tokens:] = avg_weight + noise_weight
 
 
-def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""Resize token embeddings."""
+def _description_based_initialization(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    descriptions: dict[str, str],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+    add_noise: bool = False,
+) -> None:
+    """Initialize new token embeddings based on textual descriptions.
+
+    For each new token, this function:
+    1. Tokenizes its description text
+    2. Gets embeddings of the description tokens
+    3. Averages them to initialize the new token's embedding
+    4. Optionally adds Gaussian noise
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added
+        descriptions: Dict mapping token string to its description text,
+            e.g., {"<think>": "A token representing reasoning process"}
+        tokenizer: The tokenizer instance
+        model: The model instance (used to get input embeddings)
+        add_noise: Whether to add Gaussian noise to the initialization
+
+    Example:
+        descriptions = {
+            "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
+            "<|END_OF_SVG|>": "Marks the end of an SVG document"
+        }
+    """
+    embedding_dim = embed_weight.size(1)
+
+    for i, desc in enumerate(descriptions.values()):
+        # Tokenize description text
+        tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
+
+        with torch.no_grad():
+            token_ids = tokens["input_ids"][0]
+            # Move to the same device as embed_weight
+            device = embed_weight.device
+            token_ids = token_ids.to(device)
+
+            # Filter out new tokens (they don't have valid embeddings yet)
+            valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
+
+            if len(valid_token_ids) == 0:
+                # Fallback: use mean of all existing embeddings
+                logger.warning_rank0(
+                    f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
+                    "Using mean of existing embeddings."
+                )
+                base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
+            else:
+                # Get embeddings of description tokens and average them
+                token_embeds = model.get_input_embeddings()(valid_token_ids)
+                base_embedding = token_embeds.mean(dim=0)
+
+            # Add noise if requested (ensure correct device and dtype)
+            if add_noise:
+                noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
+                embed_weight[-num_new_tokens + i] = base_embedding + noise
+            else:
+                embed_weight[-num_new_tokens + i] = base_embedding
+
+
+def _initialize_embeddings(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    init_method: str,
+    new_special_tokens_config: Optional[dict],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+) -> None:
+    """Single source of truth for embedding initialization.
+
+    This function selects the appropriate initialization method and applies it.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize
+        num_new_tokens: Number of new tokens added
+        init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+        new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
+        tokenizer: The tokenizer instance
+        model: The model instance
+    """
+    if init_method == "desc_init" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
+        )
+    elif init_method == "desc_init_w_noise" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
+        )
+    else:
+        if init_method != "noise_init":
+            logger.warning_rank0(
+                f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
+            )
+        logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
+        _noisy_mean_initialization(embed_weight, num_new_tokens)
+
+
+def resize_embedding_layer(
+    model: "PreTrainedModel",
+    tokenizer: "PreTrainedTokenizer",
+    new_special_tokens_config: Optional[dict] = None,
+    init_special_tokens: str = "noise_init",
+) -> None:
+    r"""Resize token embeddings and initialize new tokens.
+
+    Args:
+        model: The model to resize
+        tokenizer: The tokenizer (used to get target vocab size)
+        new_special_tokens_config: Optional dict with token descriptions for semantic initialization
+        init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+    """
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore
@@ -64,8 +189,30 @@
     with context_maybe_zero3:
         new_embedding_size = model.get_input_embeddings().weight.size(0)
         num_new_tokens = new_embedding_size - current_embedding_size
-        _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
-        _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
+        logger.info_rank0(
+            f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
+        )
+
+        # Initialize input embeddings
+        _initialize_embeddings(
+            model.get_input_embeddings().weight.data,
+            num_new_tokens,
+            init_special_tokens,
+            new_special_tokens_config,
+            tokenizer,
+            model,
+        )
+
+        # Initialize output embeddings if not tied
+        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
+            _initialize_embeddings(
+                model.get_output_embeddings().weight.data,
+                num_new_tokens,
+                init_special_tokens,
+                new_special_tokens_config,
+                tokenizer,
+                model,
+            )
 
     model.config.vocab_size = new_embedding_size
     logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
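In short, the three initialization paths above differ only in where the new embedding row comes from: with d the embedding dimension, E the pre-existing rows of the embedding matrix, and D the ids produced by tokenizing a token's description, `noise_init` uses mean(E) + eps with eps drawn at standard deviation 1/sqrt(d), `desc_init` uses mean(E[D]), and `desc_init_w_noise` adds the same noise on top of the `desc_init` vector. If a description tokenizes only to other new tokens (so E[D] would be empty), the code falls back to mean(E).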
"_special_token_descriptions", None), + init_special_tokens=model_args.init_special_tokens, + ) if is_trainable: if getattr(model.config, "model_type", None) == "gemma3n":