Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-10-15 16:18:10 +08:00
[model] adds semantic initialization support for special tokens (#9267)
Co-authored-by: ximingxing <ximingxing@tencent.com>
parent 3dbca4b533
commit c867e28093
examples/extras/multi_tokens/tokens_cfg.yaml (new file, 25 lines)
@@ -0,0 +1,25 @@
# SVG Container Tags
"<|START_OF_SVG|>": "Marks the beginning of an SVG document"
"<|END_OF_SVG|>": "Marks the end of an SVG document"

# SVG Group Tags
"<|start_of_g|>": "Begins a group element in SVG for organizing related shapes"
"<|end_of_g|>": "Ends a group element"

# SVG Shape Tags
"<|start_of_rect|>": "Begins a rectangle shape with width and height attributes"
"<|end_of_rect|>": "Ends a rectangle shape definition"
"<|start_of_circle|>": "Begins a circular shape with radius attribute"
"<|end_of_circle|>": "Ends a circular shape definition"
"<|start_of_path|>": "Begins a path element for drawing custom vector graphics"
"<|end_of_path|>": "Ends a path element definition"
"<|start_of_ellipse|>": "Begins an ellipse shape with x and y radii"
"<|end_of_ellipse|>": "Ends an ellipse shape definition"

# SVG Text Tags
"<|start_of_text|>": "Begins a text element for rendering text content"
"<|end_of_text|>": "Ends a text element"

# SVG Style Tags
"<|start_of_style|>": "Begins a style definition block for CSS styling"
"<|end_of_style|>": "Ends a style definition block"
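For context, a minimal sketch (not part of the commit) of how such a config file maps to tokens and descriptions; it mirrors the OmegaConf loading logic added to BaseModelArguments further down in this diff, and only the example file above is assumed to exist:

from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/extras/multi_tokens/tokens_cfg.yaml")
token_descriptions = OmegaConf.to_container(cfg)  # {"<|START_OF_SVG|>": "Marks the beginning of an SVG document", ...}
new_tokens = list(token_descriptions.keys())      # these become the added special tokens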
@@ -22,9 +22,12 @@ from typing import Any, Literal, Optional, Union
import torch
from transformers.training_args import _convert_str_dict
from typing_extensions import Self
from omegaconf import OmegaConf

from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
from ..extras.logging import get_logger

logger = get_logger(__name__)

@dataclass
class BaseModelArguments:
@@ -75,6 +78,28 @@ class BaseModelArguments:
        default=None,
        metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
    )
    new_special_tokens_config: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to YAML config with special token descriptions for semantic initialization. "
                "If set, this takes precedence over add_special_tokens. "
                "YAML format: {'<token>': 'description text', ...}"
            )
        },
    )
    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
        default="noise_init",
        metadata={
            "help": (
                "Initialization method for new special tokens: "
                "'noise_init' (default, random noise around mean), "
                "'desc_init' (semantic initialization from descriptions), "
                "'desc_init_w_noise' (semantic + random noise). "
                "Note: 'desc_init' methods require new_special_tokens_config."
            )
        },
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
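A toy sketch (illustrative only, not part of the commit) of what the three strategies named in the help text compute for one new token; the 1.0 / math.sqrt(embedding_dim) noise scale follows the desc_init_w_noise branch shown later in this diff:

import math
import torch

dim = 4                              # toy embedding dimension
old_embeds = torch.randn(10, dim)    # existing embedding rows
desc_ids = torch.tensor([1, 3, 7])   # ids of the tokens in the description text
noise = torch.randn(dim) * (1.0 / math.sqrt(dim))

e_noise_init        = old_embeds.mean(dim=0) + noise            # 'noise_init'
e_desc_init         = old_embeds[desc_ids].mean(dim=0)          # 'desc_init'
e_desc_init_w_noise = old_embeds[desc_ids].mean(dim=0) + noise  # 'desc_init_w_noise'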
@@ -185,8 +210,63 @@ class BaseModelArguments:
        if self.add_tokens is not None:  # support multiple tokens
            self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]

        if self.add_special_tokens is not None:  # support multiple special tokens
        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
        if self.new_special_tokens_config is not None:
            # Priority 1: Load from YAML config (extracts both tokens and descriptions)
            try:
                cfg = OmegaConf.load(self.new_special_tokens_config)
                token_descriptions = OmegaConf.to_container(cfg)

                if not isinstance(token_descriptions, dict):
                    raise ValueError(
                        f"YAML config must be a dictionary mapping tokens to descriptions. "
                        f"Got: {type(token_descriptions)}"
                    )

                # Extract token list from config keys
                extracted_tokens = list(token_descriptions.keys())

                # Warn if both are set
                if self.add_special_tokens is not None:
                    logger.warning_rank0(
                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
                        f"Using tokens from config: {extracted_tokens}"
                    )

                # Override add_special_tokens with extracted tokens (as list)
                self.add_special_tokens = extracted_tokens

                # Store descriptions internally for later use (internal attribute)
                self._special_token_descriptions = token_descriptions

                logger.info_rank0(
                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
                    f"{self.new_special_tokens_config}"
                )

            except Exception as e:
                logger.error_rank0(
                    f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
                )
                raise

        elif self.add_special_tokens is not None:
            # Priority 2: Use simple comma-separated string (no descriptions)
            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
            self._special_token_descriptions = None

        else:
            # No special tokens to add
            self._special_token_descriptions = None

        # Validate init method
        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
            if self._special_token_descriptions is None:
                logger.warning_rank0(
                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
                    "Falling back to 'noise_init'"
                )
                self.init_special_tokens = "noise_init"


@dataclass
@@ -14,7 +14,7 @@

import math
from contextlib import nullcontext
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import torch
from transformers.integrations import is_deepspeed_zero3_enabled
@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)


def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
    """Initialize new token embeddings with mean + Gaussian noise.

    This is the default initialization method used by LlamaFactory.

    Args:
        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
        num_new_tokens: Number of new tokens added at the end of the embedding matrix
    """
    embedding_dim = embed_weight.size(1)
    avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
    noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int
    embed_weight[-num_new_tokens:] = avg_weight + noise_weight


def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
    r"""Resize token embeddings."""
def _description_based_initialization(
    embed_weight: "torch.Tensor",
    num_new_tokens: int,
    descriptions: dict[str, str],
    tokenizer: "PreTrainedTokenizer",
    model: "PreTrainedModel",
    add_noise: bool = False,
) -> None:
    """Initialize new token embeddings based on textual descriptions.

    For each new token, this function:
    1. Tokenizes its description text
    2. Gets embeddings of the description tokens
    3. Averages them to initialize the new token's embedding
    4. Optionally adds Gaussian noise

    Args:
        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
        num_new_tokens: Number of new tokens added
        descriptions: Dict mapping token string to its description text
            e.g., {"<think>": "A token representing reasoning process"}
        tokenizer: The tokenizer instance
        model: The model instance (used to get input embeddings)
        add_noise: Whether to add Gaussian noise to the initialization

    Example:
        descriptions = {
            "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
            "<|END_OF_SVG|>": "Marks the end of an SVG document"
        }
    """
    embedding_dim = embed_weight.size(1)

    for i, desc in enumerate(descriptions.values()):
        # Tokenize description text
        tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)

        with torch.no_grad():
            token_ids = tokens["input_ids"][0]
            # Move to the same device as embed_weight
            device = embed_weight.device
            token_ids = token_ids.to(device)

            # Filter out new tokens (they don't have valid embeddings yet)
            valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]

            if len(valid_token_ids) == 0:
                # Fallback: use mean of all existing embeddings
                logger.warning_rank0(
                    f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
                    "Using mean of existing embeddings."
                )
                base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
            else:
                # Get embeddings of description tokens and average them
                token_embeds = model.get_input_embeddings()(valid_token_ids)
                base_embedding = token_embeds.mean(dim=0)

            # Add noise if requested (ensure correct device and dtype)
            if add_noise:
                noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
                embed_weight[-num_new_tokens + i] = base_embedding + noise
            else:
                embed_weight[-num_new_tokens + i] = base_embedding


def _initialize_embeddings(
    embed_weight: "torch.Tensor",
    num_new_tokens: int,
    init_method: str,
    new_special_tokens_config: Optional[dict],
    tokenizer: "PreTrainedTokenizer",
    model: "PreTrainedModel",
) -> None:
    """Single source of truth for embedding initialization.

    This function selects the appropriate initialization method and applies it.

    Args:
        embed_weight: The embedding weight matrix to initialize
        num_new_tokens: Number of new tokens added
        init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
        new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
        tokenizer: The tokenizer instance
        model: The model instance
    """
    if init_method == "desc_init" and new_special_tokens_config:
        logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
        _description_based_initialization(
            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
        )
    elif init_method == "desc_init_w_noise" and new_special_tokens_config:
        logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
        _description_based_initialization(
            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
        )
    else:
        if init_method != "noise_init":
            logger.warning_rank0(
                f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
            )
        logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
        _noisy_mean_initialization(embed_weight, num_new_tokens)


def resize_embedding_layer(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizer",
    new_special_tokens_config: Optional[dict] = None,
    init_special_tokens: str = "noise_init",
) -> None:
    r"""Resize token embeddings and initialize new tokens.

    Args:
        model: The model to resize
        tokenizer: The tokenizer (used to get target vocab size)
        new_special_tokens_config: Optional dict with token descriptions for semantic initialization
        init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
    """
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore
@@ -64,8 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
    with context_maybe_zero3:
        new_embedding_size = model.get_input_embeddings().weight.size(0)
        num_new_tokens = new_embedding_size - current_embedding_size
        _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
        _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
        logger.info_rank0(
            f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
        )

        # Initialize input embeddings
        _initialize_embeddings(
            model.get_input_embeddings().weight.data,
            num_new_tokens,
            init_special_tokens,
            new_special_tokens_config,
            tokenizer,
            model,
        )

        # Initialize output embeddings if not tied
        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
            _initialize_embeddings(
                model.get_output_embeddings().weight.data,
                num_new_tokens,
                init_special_tokens,
                new_special_tokens_config,
                tokenizer,
                model,
            )

    model.config.vocab_size = new_embedding_size
    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
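For reference, a hedged sketch (not part of the patch) of calling the updated resize_embedding_layer directly; in LLaMA-Factory this is normally wired up through patch_model as in the next hunk, the module path is assumed from the current repo layout, and the model name is purely illustrative:

from omegaconf import OmegaConf
from transformers import AutoModelForCausalLM, AutoTokenizer

from llamafactory.model.model_utils.embedding import resize_embedding_layer  # path assumed

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")      # illustrative model
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

descriptions = OmegaConf.to_container(OmegaConf.load("examples/extras/multi_tokens/tokens_cfg.yaml"))
tokenizer.add_special_tokens({"additional_special_tokens": list(descriptions)})

# Resize the embedding matrix and seed the new rows from their descriptions.
resize_embedding_layer(model, tokenizer, new_special_tokens_config=descriptions, init_special_tokens="desc_init")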
@@ -175,7 +175,12 @@ def patch_model(
        prepare_valuehead_model(model)

    if model_args.resize_vocab:
        resize_embedding_layer(model, tokenizer)
        resize_embedding_layer(
            model,
            tokenizer,
            new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
            init_special_tokens=model_args.init_special_tokens,
        )

    if is_trainable:
        if getattr(model.config, "model_type", None) == "gemma3n":
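Taken together, a user would enable the feature through the model arguments; a hedged sketch of the relevant keys follows (the two new argument names come from this diff, resize_vocab is the pre-existing flag that gates the resize in patch_model, and everything else, including the model name, is illustrative):

# Sketch of the relevant training arguments, not a full LLaMA-Factory config.
model_args_sketch = dict(
    model_name_or_path="meta-llama/Meta-Llama-3-8B",                          # illustrative
    resize_vocab=True,                                                        # required so patch_model resizes embeddings
    new_special_tokens_config="examples/extras/multi_tokens/tokens_cfg.yaml",
    init_special_tokens="desc_init",                                          # or "desc_init_w_noise" / "noise_init"
)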