mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-16 16:48:11 +08:00

[model] adds semantic initialization support for special tokens (#9267)
Co-authored-by: ximingxing <ximingxing@tencent.com>
This commit is contained in:
parent 3dbca4b533
commit c867e28093

examples/extras/multi_tokens/tokens_cfg.yaml (new file, 25 lines added)
@@ -0,0 +1,25 @@
+# SVG Container Tags
+"<|START_OF_SVG|>": "Marks the beginning of an SVG document"
+"<|END_OF_SVG|>": "Marks the end of an SVG document"
+
+# SVG Group Tags
+"<|start_of_g|>": "Begins a group element in SVG for organizing related shapes"
+"<|end_of_g|>": "Ends a group element"
+
+# SVG Shape Tags
+"<|start_of_rect|>": "Begins a rectangle shape with width and height attributes"
+"<|end_of_rect|>": "Ends a rectangle shape definition"
+"<|start_of_circle|>": "Begins a circular shape with radius attribute"
+"<|end_of_circle|>": "Ends a circular shape definition"
+"<|start_of_path|>": "Begins a path element for drawing custom vector graphics"
+"<|end_of_path|>": "Ends a path element definition"
+"<|start_of_ellipse|>": "Begins an ellipse shape with x and y radii"
+"<|end_of_ellipse|>": "Ends an ellipse shape definition"
+
+# SVG Text Tags
+"<|start_of_text|>": "Begins a text element for rendering text content"
+"<|end_of_text|>": "Ends a text element"
+
+# SVG Style Tags
+"<|start_of_style|>": "Begins a style definition block for CSS styling"
+"<|end_of_style|>": "Ends a style definition block"
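For context, a description file like the one above is meant to be referenced from the training configuration through the new new_special_tokens_config and init_special_tokens arguments introduced below. A minimal, hypothetical excerpt of such a training YAML (only resize_vocab, new_special_tokens_config, and init_special_tokens come from this patch; the model name is an arbitrary placeholder):

### model (illustrative excerpt)
model_name_or_path: Qwen/Qwen2.5-7B-Instruct   # placeholder, any base model
resize_vocab: true
new_special_tokens_config: examples/extras/multi_tokens/tokens_cfg.yaml
init_special_tokens: desc_init                 # or desc_init_w_noise / noise_init
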
@@ -22,9 +22,12 @@ from typing import Any, Literal, Optional, Union

 import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
+from omegaconf import OmegaConf

 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
+from ..extras.logging import get_logger
+
+logger = get_logger(__name__)


 @dataclass
 class BaseModelArguments:
@@ -75,6 +78,28 @@ class BaseModelArguments:
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
+    new_special_tokens_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to YAML config with special token descriptions for semantic initialization. "
+                "If set, this takes precedence over add_special_tokens. "
+                "YAML format: {'<token>': 'description text', ...}"
+            )
+        },
+    )
+    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
+        default="noise_init",
+        metadata={
+            "help": (
+                "Initialization method for new special tokens: "
+                "'noise_init' (default, random noise around mean), "
+                "'desc_init' (semantic initialization from descriptions), "
+                "'desc_init_w_noise' (semantic + random noise). "
+                "Note: 'desc_init' methods require new_special_tokens_config."
+            )
+        },
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -185,8 +210,63 @@ class BaseModelArguments:
         if self.add_tokens is not None:  # support multiple tokens
             self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]

-        if self.add_special_tokens is not None:  # support multiple special tokens
+        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
+        if self.new_special_tokens_config is not None:
+            # Priority 1: Load from YAML config (extracts both tokens and descriptions)
+            try:
+                cfg = OmegaConf.load(self.new_special_tokens_config)
+                token_descriptions = OmegaConf.to_container(cfg)
+
+                if not isinstance(token_descriptions, dict):
+                    raise ValueError(
+                        f"YAML config must be a dictionary mapping tokens to descriptions. "
+                        f"Got: {type(token_descriptions)}"
+                    )
+
+                # Extract token list from config keys
+                extracted_tokens = list(token_descriptions.keys())
+
+                # Warn if both are set
+                if self.add_special_tokens is not None:
+                    logger.warning_rank0(
+                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
+                        f"Using tokens from config: {extracted_tokens}"
+                    )
+
+                # Override add_special_tokens with extracted tokens (as list)
+                self.add_special_tokens = extracted_tokens
+
+                # Store descriptions internally for later use (internal attribute)
+                self._special_token_descriptions = token_descriptions
+
+                logger.info_rank0(
+                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
+                    f"{self.new_special_tokens_config}"
+                )
+
+            except Exception as e:
+                logger.error_rank0(
+                    f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
+                )
+                raise
+
+        elif self.add_special_tokens is not None:
+            # Priority 2: Use simple comma-separated string (no descriptions)
             self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
+            self._special_token_descriptions = None
+
+        else:
+            # No special tokens to add
+            self._special_token_descriptions = None
+
+        # Validate init method
+        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
+            if self._special_token_descriptions is None:
+                logger.warning_rank0(
+                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
+                    "Falling back to 'noise_init'"
+                )
+                self.init_special_tokens = "noise_init"


 @dataclass
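For illustration (not part of the patch), this is roughly what the new_special_tokens_config branch of __post_init__ above produces for the example tokens_cfg.yaml, using the real OmegaConf API:

import pprint
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/extras/multi_tokens/tokens_cfg.yaml")
token_descriptions = OmegaConf.to_container(cfg)      # plain dict: {token: description}

add_special_tokens = list(token_descriptions.keys())  # what self.add_special_tokens becomes
pprint.pprint(add_special_tokens[:2])
# ['<|START_OF_SVG|>', '<|END_OF_SVG|>']
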
@@ -14,7 +14,7 @@

 import math
 from contextlib import nullcontext
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled
@@ -30,6 +30,14 @@ logger = logging.get_logger(__name__)


 def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
+    """Initialize new token embeddings with mean + Gaussian noise.
+
+    This is the default initialization method used by LlamaFactory.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added at the end of the embedding matrix
+    """
     embedding_dim = embed_weight.size(1)
     avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
     noise_weight = torch.empty_like(embed_weight[-num_new_tokens:])
@@ -37,8 +45,125 @@ def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
     embed_weight[-num_new_tokens:] = avg_weight + noise_weight


-def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
-    r"""Resize token embeddings."""
+def _description_based_initialization(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    descriptions: dict[str, str],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+    add_noise: bool = False,
+) -> None:
+    """Initialize new token embeddings based on textual descriptions.
+
+    For each new token, this function:
+    1. Tokenizes its description text
+    2. Gets embeddings of the description tokens
+    3. Averages them to initialize the new token's embedding
+    4. Optionally adds Gaussian noise
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize (shape: [vocab_size, embedding_dim])
+        num_new_tokens: Number of new tokens added
+        descriptions: Dict mapping token string to its description text,
+            e.g., {"<think>": "A token representing reasoning process"}
+        tokenizer: The tokenizer instance
+        model: The model instance (used to get input embeddings)
+        add_noise: Whether to add Gaussian noise to the initialization
+
+    Example:
+        descriptions = {
+            "<|START_OF_SVG|>": "Marks the beginning of an SVG document",
+            "<|END_OF_SVG|>": "Marks the end of an SVG document"
+        }
+    """
+    embedding_dim = embed_weight.size(1)
+
+    for i, desc in enumerate(descriptions.values()):
+        # Tokenize description text
+        tokens = tokenizer(desc, return_tensors="pt", add_special_tokens=False)
+
+        with torch.no_grad():
+            token_ids = tokens["input_ids"][0]
+            # Move to the same device as embed_weight
+            device = embed_weight.device
+            token_ids = token_ids.to(device)
+
+            # Filter out new tokens (they don't have valid embeddings yet)
+            valid_token_ids = token_ids[token_ids < (len(tokenizer) - num_new_tokens)]
+
+            if len(valid_token_ids) == 0:
+                # Fallback: use mean of all existing embeddings
+                logger.warning_rank0(
+                    f"Description for token {i+1}/{num_new_tokens} contains no valid tokens. "
+                    "Using mean of existing embeddings."
+                )
+                base_embedding = embed_weight[:-num_new_tokens].mean(dim=0)
+            else:
+                # Get embeddings of description tokens and average them
+                token_embeds = model.get_input_embeddings()(valid_token_ids)
+                base_embedding = token_embeds.mean(dim=0)
+
+            # Add noise if requested (ensure correct device and dtype)
+            if add_noise:
+                noise = torch.randn_like(base_embedding) * (1.0 / math.sqrt(embedding_dim))
+                embed_weight[-num_new_tokens + i] = base_embedding + noise
+            else:
+                embed_weight[-num_new_tokens + i] = base_embedding
+
+
+def _initialize_embeddings(
+    embed_weight: "torch.Tensor",
+    num_new_tokens: int,
+    init_method: str,
+    new_special_tokens_config: Optional[dict],
+    tokenizer: "PreTrainedTokenizer",
+    model: "PreTrainedModel",
+) -> None:
+    """Single source of truth for embedding initialization.
+
+    This function selects the appropriate initialization method and applies it.
+
+    Args:
+        embed_weight: The embedding weight matrix to initialize
+        num_new_tokens: Number of new tokens added
+        init_method: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+        new_special_tokens_config: Config dict with token descriptions (required for desc_init methods)
+        tokenizer: The tokenizer instance
+        model: The model instance
+    """
+    if init_method == "desc_init" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization (desc_init) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=False
+        )
+    elif init_method == "desc_init_w_noise" and new_special_tokens_config:
+        logger.info_rank0("Using semantic initialization with noise (desc_init_w_noise) for new special tokens")
+        _description_based_initialization(
+            embed_weight, num_new_tokens, new_special_tokens_config, tokenizer, model, add_noise=True
+        )
+    else:
+        if init_method != "noise_init":
+            logger.warning_rank0(
+                f"init_method='{init_method}' requires descriptions config, falling back to 'noise_init'"
+            )
+        logger.info_rank0("Using noisy mean initialization (noise_init) for new special tokens")
+        _noisy_mean_initialization(embed_weight, num_new_tokens)
+
+
+def resize_embedding_layer(
+    model: "PreTrainedModel",
+    tokenizer: "PreTrainedTokenizer",
+    new_special_tokens_config: Optional[dict] = None,
+    init_special_tokens: str = "noise_init",
+) -> None:
+    r"""Resize token embeddings and initialize new tokens.
+
+    Args:
+        model: The model to resize
+        tokenizer: The tokenizer (used to get target vocab size)
+        new_special_tokens_config: Optional dict with token descriptions for semantic initialization
+        init_special_tokens: Initialization method ('noise_init', 'desc_init', 'desc_init_w_noise')
+    """
     if is_deepspeed_zero3_enabled():
         import deepspeed  # type: ignore

@@ -64,8 +189,30 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
     with context_maybe_zero3:
         new_embedding_size = model.get_input_embeddings().weight.size(0)
         num_new_tokens = new_embedding_size - current_embedding_size
-        _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
-        _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
+        logger.info_rank0(
+            f"Resizing embeddings: {current_embedding_size} -> {new_embedding_size} (+{num_new_tokens} tokens)"
+        )
+
+        # Initialize input embeddings
+        _initialize_embeddings(
+            model.get_input_embeddings().weight.data,
+            num_new_tokens,
+            init_special_tokens,
+            new_special_tokens_config,
+            tokenizer,
+            model,
+        )
+
+        # Initialize output embeddings if not tied
+        if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings:
+            _initialize_embeddings(
+                model.get_output_embeddings().weight.data,
+                num_new_tokens,
+                init_special_tokens,
+                new_special_tokens_config,
+                tokenizer,
+                model,
+            )

     model.config.vocab_size = new_embedding_size
     logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
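A self-contained toy sketch of the desc_init idea implemented above (toy tensors only, not the LLaMA-Factory API): each new row of the embedding matrix starts as the mean of the embeddings of the tokens in its description, optionally plus Gaussian noise scaled by 1/sqrt(dim):

import math
import torch

torch.manual_seed(0)
old_vocab, num_new_tokens, dim = 10, 2, 8
embed_weight = torch.randn(old_vocab + num_new_tokens, dim)

# Pretend these id lists came from tokenizing each new token's description text.
description_token_ids = [torch.tensor([1, 4, 7]), torch.tensor([2, 3])]

add_noise = True  # desc_init_w_noise; set False for plain desc_init
for i, token_ids in enumerate(description_token_ids):
    valid = token_ids[token_ids < old_vocab]      # keep only previously trained rows
    base = embed_weight[valid].mean(dim=0)        # average the description embeddings
    if add_noise:
        base = base + torch.randn_like(base) * (1.0 / math.sqrt(dim))
    embed_weight[old_vocab + i] = base
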
@@ -175,7 +175,12 @@ def patch_model(
         prepare_valuehead_model(model)

     if model_args.resize_vocab:
-        resize_embedding_layer(model, tokenizer)
+        resize_embedding_layer(
+            model,
+            tokenizer,
+            new_special_tokens_config=getattr(model_args, "_special_token_descriptions", None),
+            init_special_tokens=model_args.init_special_tokens,
+        )

     if is_trainable:
         if getattr(model.config, "model_type", None) == "gemma3n":
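A hypothetical sanity check (not part of the patch) after patch_model has run: with desc_init, a new token's input embedding should match the mean embedding of its description tokens, and with desc_init_w_noise it should stay close to it:

import torch

def desc_init_similarity(model, tokenizer, token: str, description: str) -> float:
    """Cosine similarity between a new token's embedding and the mean of its description embeddings."""
    embeds = model.get_input_embeddings().weight.data
    token_id = tokenizer.convert_tokens_to_ids(token)
    desc_ids = tokenizer(description, add_special_tokens=False, return_tensors="pt")["input_ids"][0]
    expected = embeds[desc_ids].mean(dim=0)
    return torch.nn.functional.cosine_similarity(embeds[token_id], expected, dim=0).item()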