mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 09:52:14 +08:00
[model] adds semantic initialization support for special tokens (#9267)

Co-authored-by: ximingxing <ximingxing@tencent.com>

parent 3dbca4b533
commit c867e28093
examples/extras/multi_tokens/tokens_cfg.yaml (25 lines added, normal file)
@@ -0,0 +1,25 @@
+# SVG Container Tags
+"<|START_OF_SVG|>": "Marks the beginning of an SVG document"
+"<|END_OF_SVG|>": "Marks the end of an SVG document"
+
+# SVG Group Tags
+"<|start_of_g|>": "Begins a group element in SVG for organizing related shapes"
+"<|end_of_g|>": "Ends a group element"
+
+# SVG Shape Tags
+"<|start_of_rect|>": "Begins a rectangle shape with width and height attributes"
+"<|end_of_rect|>": "Ends a rectangle shape definition"
+"<|start_of_circle|>": "Begins a circular shape with radius attribute"
+"<|end_of_circle|>": "Ends a circular shape definition"
+"<|start_of_path|>": "Begins a path element for drawing custom vector graphics"
+"<|end_of_path|>": "Ends a path element definition"
+"<|start_of_ellipse|>": "Begins an ellipse shape with x and y radii"
+"<|end_of_ellipse|>": "Ends an ellipse shape definition"
+
+# SVG Text Tags
+"<|start_of_text|>": "Begins a text element for rendering text content"
+"<|end_of_text|>": "Ends a text element"
+
+# SVG Style Tags
+"<|start_of_style|>": "Begins a style definition block for CSS styling"
+"<|end_of_style|>": "Ends a style definition block"
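The keys of this YAML file become the new special tokens and the values become their descriptions. A minimal, hedged sketch of reading it back with OmegaConf, mirroring the parsing this commit adds to the model arguments (the path is the file added above; everything else is illustrative):

from omegaconf import OmegaConf

# Keys -> new special tokens, values -> their natural-language descriptions.
cfg = OmegaConf.load("examples/extras/multi_tokens/tokens_cfg.yaml")
token_descriptions = OmegaConf.to_container(cfg)  # plain dict[str, str]

special_tokens = list(token_descriptions.keys())
print(special_tokens[0])                      # "<|START_OF_SVG|>"
print(token_descriptions[special_tokens[0]])  # "Marks the beginning of an SVG document"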
@@ -22,9 +22,12 @@ from typing import Any, Literal, Optional, Union
 import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
+from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
+from ..extras.logging import get_logger
 
+logger = get_logger(__name__)
 
 @dataclass
 class BaseModelArguments:
@@ -75,6 +78,28 @@ class BaseModelArguments:
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
+    new_special_tokens_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to YAML config with special token descriptions for semantic initialization. "
+                "If set, this takes precedence over add_special_tokens. "
+                "YAML format: {'<token>': 'description text', ...}"
+            )
+        },
+    )
+    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
+        default="noise_init",
+        metadata={
+            "help": (
+                "Initialization method for new special tokens: "
+                "'noise_init' (default, random noise around mean), "
+                "'desc_init' (semantic initialization from descriptions), "
+                "'desc_init_w_noise' (semantic + random noise). "
+                "Note: 'desc_init' methods require new_special_tokens_config."
+            )
+        },
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
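In user-facing terms, the two new fields pair with the existing resize_vocab flag. A hedged sketch of the relevant argument values (the names come from this diff; any surrounding training options are omitted and assumed):

# Illustrative only: the arguments involved in semantic special-token initialization.
model_args = {
    "resize_vocab": True,  # the resize/initialization path below only runs when this is set
    "new_special_tokens_config": "examples/extras/multi_tokens/tokens_cfg.yaml",
    "init_special_tokens": "desc_init",  # or "noise_init" / "desc_init_w_noise"
}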
@@ -185,8 +210,63 @@ class BaseModelArguments:
         if self.add_tokens is not None:  # support multiple tokens
             self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
 
-        if self.add_special_tokens is not None:  # support multiple special tokens
-            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
+        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
+        if self.new_special_tokens_config is not None:
+            # Priority 1: Load from YAML config (extracts both tokens and descriptions)
+            try:
+                cfg = OmegaConf.load(self.new_special_tokens_config)
+                token_descriptions = OmegaConf.to_container(cfg)
+
+                if not isinstance(token_descriptions, dict):
+                    raise ValueError(
+                        f"YAML config must be a dictionary mapping tokens to descriptions. "
+                        f"Got: {type(token_descriptions)}"
+                    )
+
+                # Extract token list from config keys
+                extracted_tokens = list(token_descriptions.keys())
+
+                # Warn if both are set
+                if self.add_special_tokens is not None:
+                    logger.warning_rank0(
+                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
+                        f"Using tokens from config: {extracted_tokens}"
+                    )
+
+                # Override add_special_tokens with extracted tokens (as list)
+                self.add_special_tokens = extracted_tokens
+
+                # Store descriptions internally for later use (internal attribute)
+                self._special_token_descriptions = token_descriptions
+
+                logger.info_rank0(
+                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
+                    f"{self.new_special_tokens_config}"
+                )
+
+            except Exception as e:
+                logger.error_rank0(
+                    f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
+                )
+                raise
+
+        elif self.add_special_tokens is not None:
+            # Priority 2: Use simple comma-separated string (no descriptions)
+            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
+            self._special_token_descriptions = None
+
+        else:
+            # No special tokens to add
+            self._special_token_descriptions = None
+
+        # Validate init method
+        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
+            if self._special_token_descriptions is None:
+                logger.warning_rank0(
+                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
+                    "Falling back to 'noise_init'"
+                )
+                self.init_special_tokens = "noise_init"
 
 
 @dataclass
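Taken together, the precedence and fallback rules above amount to the following, shown as a hedged sketch with plain dicts standing in for the dataclass (values are illustrative, not captured output):

# Path 1: a YAML config is given, so it wins; descriptions are kept for semantic init.
config_driven = {
    "new_special_tokens_config": "examples/extras/multi_tokens/tokens_cfg.yaml",
    "init_special_tokens": "desc_init",
}
# After __post_init__: add_special_tokens == the YAML keys,
# and _special_token_descriptions == the full {token: description} dict.

# Path 2: only a comma-separated string is given, so there are no descriptions;
# a desc_init request is downgraded to noise_init with a warning.
string_driven = {
    "add_special_tokens": "<|START_OF_SVG|>,<|END_OF_SVG|>",
    "init_special_tokens": "desc_init",  # falls back to "noise_init"
}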
@@ -14,7 +14,7 @@
 import math
 from contextlib import nullcontext
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)
 
 
 def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
+    """Initialize new token embeddings with mean + Gaussian noise.
+
+    This is the default initialization method used by LlamaFactory.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added at the end of the embedding matrix
+    """
     embedding_dim = embed_weight.size(1)
     avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
     noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
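For intuition, a hedged toy version of the mean-plus-noise scheme on a small random matrix (standalone tensors, not the library function itself; the 1/sqrt(dim) noise scale follows the one used elsewhere in this diff):

import math
import torch

vocab, dim, num_new = 10, 4, 2
embed = torch.randn(vocab + num_new, dim)           # pretend the matrix was already resized

avg = embed[:-num_new].mean(dim=0, keepdim=True)    # mean of the pre-existing rows
noise = torch.empty(num_new, dim).normal_(mean=0.0, std=1.0 / math.sqrt(dim))
embed[-num_new:] = avg + noise                      # new rows = mean + small Gaussian noise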
@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
     embed_weight[-num_new_tokens:] = avg_weight + noise_weight
 
 
-def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""Resize token embeddings."""
+def _description_based_initialization(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    descriptions: dict[str, str],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+    add_noise: bool = False,
+) -> None:
+    """Initialize new token embeddings based on textual descriptions.
+
+    For each new token, this function:
+    1. Tokenizes its description text
+    2. Gets embeddings of the description tokens
+    3. Averages them to initialize the new token's embedding
+    4. Optionally adds Gaussian noise
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added
+        descriptions: Dict mapping token string to its description text
+                      e.g., {"<think>": "A token representing reasoning process"}
+        tokenizer: The tokenizer instance
+        model: The model instance (used to get input embeddings)
+        add_noise: Whether to add Gaussian noise to the initialization
+
+    Example:
+        descriptions = {
+            "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
+            "<|END_OF_SVG|>": "Marks the end of an SVG document"
+        }
+    """
+    embedding_dim = embed_weight.size(1)
+
+    for i, desc in enumerate(descriptions.values()):
+        # Tokenize description text
+        tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
+
+        with torch.no_grad():
+            token_ids = tokens["input_ids"][0]
+            # Move to the same device as embed_weight
+            device = embed_weight.device
+            token_ids = token_ids.to(device)
+
+            # Filter out new tokens (they don't have valid embeddings yet)
+            valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
+
+            if len(valid_token_ids) == 0:
+                # Fallback: use mean of all existing embeddings
+                logger.warning_rank0(
+                    f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
+                    "Using mean of existing embeddings."
+                )
+                base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
+            else:
+                # Get embeddings of description tokens and average them
+                token_embeds = model.get_input_embeddings()(valid_token_ids)
+                base_embedding = token_embeds.mean(dim=0)
+
+            # Add noise if requested (ensure correct device and dtype)
+            if add_noise:
+                noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
+                embed_weight[-num_new_tokens + i] = base_embedding + noise
+            else:
+                embed_weight[-num_new_tokens + i] = base_embedding
+
+
+def _initialize_embeddings(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    init_method: str,
+    new_special_tokens_config: Optional[dict],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+) -> None:
+    """Single source of truth for embedding initialization.
+
+    This function selects the appropriate initialization method and applies it.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize
+        num_new_tokens: Number of new tokens added
+        init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+        new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
+        tokenizer: The tokenizer instance
+        model: The model instance
+    """
+    if init_method == "desc_init" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
+        )
+    elif init_method == "desc_init_w_noise" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
+        )
+    else:
+        if init_method != "noise_init":
+            logger.warning_rank0(
+                f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
+            )
+        logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
+        _noisy_mean_initialization(embed_weight, num_new_tokens)
+
+
+def resize_embedding_layer(
+    model: "PreTrainedModel",
+    tokenizer: "PreTrainedTokenizer",
+    new_special_tokens_config: Optional[dict] = None,
+    init_special_tokens: str = "noise_init",
+) -> None:
+    r"""Resize token embeddings and initialize new tokens.
+
+    Args:
+        model: The model to resize
+        tokenizer: The tokenizer (used to get target vocab size)
+        new_special_tokens_config: Optional dict with token descriptions for semantic initialization
+        init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+    """
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore
 
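A hedged toy illustration of desc_init with made-up tensors in place of a real tokenizer and model: each new token's row becomes the mean of the embedding rows of its description's pre-existing tokens, optionally plus noise:

import math
import torch

old_vocab, dim, num_new = 8, 4, 1
embed = torch.randn(old_vocab + num_new, dim)          # already-resized embedding matrix

# Pretend "Marks the beginning of an SVG document" tokenized to these ids; in the real
# code, ids that point at the new (not yet initialized) rows are filtered out first.
description_ids = torch.tensor([1, 5, 2, 7])

base = embed[description_ids].mean(dim=0)               # desc_init
embed[-1] = base + torch.randn(dim) / math.sqrt(dim)    # desc_init_w_noise variant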
@@ -64,8 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
         with context_maybe_zero3:
             new_embedding_size = model.get_input_embeddings().weight.size(0)
             num_new_tokens = new_embedding_size - current_embedding_size
-            _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
-            _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
+            logger.info_rank0(
+                f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
+            )
+
+            # Initialize input embeddings
+            _initialize_embeddings(
+                model.get_input_embeddings().weight.data,
+                num_new_tokens,
+                init_special_tokens,
+                new_special_tokens_config,
+                tokenizer,
+                model,
+            )
+
+            # Initialize output embeddings if not tied
+            if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
+                _initialize_embeddings(
+                    model.get_output_embeddings().weight.data,
+                    num_new_tokens,
+                    init_special_tokens,
+                    new_special_tokens_config,
+                    tokenizer,
+                    model,
+                )
 
         model.config.vocab_size = new_embedding_size
         logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
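For context, a hedged, self-contained sketch of where the num_new_tokens counted above comes from once special tokens are registered on the tokenizer and the model is resized (the throwaway randomly initialized GPT-2 exists only so the sketch runs; it is not part of this commit):

from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = GPT2Config(vocab_size=len(tokenizer), n_layer=1, n_head=2, n_embd=8, n_positions=32)
model = GPT2LMHeadModel(config)  # tiny, untrained, just for the row count

before = model.get_input_embeddings().weight.size(0)
tokenizer.add_special_tokens({"additional_special_tokens": ["<|START_OF_SVG|>", "<|END_OF_SVG|>"]})
model.resize_token_embeddings(len(tokenizer))

num_new_tokens = model.get_input_embeddings().weight.size(0) - before  # 2 new rows to initialize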
@@ -175,7 +175,12 @@ def patch_model(
         prepare_valuehead_model(model)
 
     if model_args.resize_vocab:
-        resize_embedding_layer(model, tokenizer)
+        resize_embedding_layer(
+            model,
+            tokenizer,
+            new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
+            init_special_tokens=model_args.init_special_tokens,
+        )
 
     if is_trainable:
         if getattr(model.config, "model_type", None) == "gemma3n":