[model] adds semantic initialization support for special tokens (#9267)

Co-authored-by: ximingxing <ximingxing@tencent.com>
Authored by: Ximing Xing
Date: 2025-10-14 17:00:48 +08:00
Committed by: GitHub
parent 3dbca4b533
commit c867e28093
4 changed files with 264 additions and 7 deletions
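
Taken together, the new options let a training config point at a token-description file and choose an initialization method. A hypothetical excerpt (file names invented for illustration):

# train_config.yaml (hypothetical excerpt)
new_special_tokens_config: configs/special_tokens.yaml
init_special_tokens: desc_init  # or noise_init / desc_init_w_noise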


@@ -22,9 +22,12 @@ from typing import Any, Literal, Optional, Union
 import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
+from omegaconf import OmegaConf
 
 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
+from ..extras.logging import get_logger
 
+logger = get_logger(__name__)
 
 @dataclass
 class BaseModelArguments:
@@ -75,6 +78,28 @@ class BaseModelArguments:
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
+    new_special_tokens_config: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to a YAML config with special token descriptions for semantic initialization. "
+                "If set, this takes precedence over add_special_tokens. "
+                "YAML format: {'<token>': 'description text', ...}"
+            )
+        },
+    )
+    init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
+        default="noise_init",
+        metadata={
+            "help": (
+                "Initialization method for new special tokens: "
+                "'noise_init' (default, random noise around the embedding mean), "
+                "'desc_init' (semantic initialization from descriptions), "
+                "'desc_init_w_noise' (semantic initialization plus random noise). "
+                "Note: the 'desc_init' methods require new_special_tokens_config."
+            )
+        },
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -185,8 +210,63 @@ class BaseModelArguments:
         if self.add_tokens is not None:  # support multiple tokens
             self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
-        if self.add_special_tokens is not None:  # support multiple special tokens
-            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
+        # Process special tokens with priority: new_special_tokens_config > add_special_tokens
+        if self.new_special_tokens_config is not None:
+            # Priority 1: load from the YAML config (yields both tokens and descriptions)
+            try:
+                cfg = OmegaConf.load(self.new_special_tokens_config)
+                token_descriptions = OmegaConf.to_container(cfg)
+                if not isinstance(token_descriptions, dict):
+                    raise ValueError(
+                        "YAML config must be a dictionary mapping tokens to descriptions. "
+                        f"Got: {type(token_descriptions)}"
+                    )
+
+                # Extract the token list from the config keys
+                extracted_tokens = list(token_descriptions.keys())
+
+                # Warn if both options are set
+                if self.add_special_tokens is not None:
+                    logger.warning_rank0(
+                        "Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
+                        f"Using tokens from config: {extracted_tokens}"
+                    )
+
+                # Override add_special_tokens with the extracted tokens (as a list)
+                self.add_special_tokens = extracted_tokens
+                # Store the descriptions on an internal attribute for later use
+                self._special_token_descriptions = token_descriptions
+                logger.info_rank0(
+                    f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
+                    f"{self.new_special_tokens_config}"
+                )
+            except Exception as e:
+                logger.error_rank0(
+                    f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
+                )
+                raise
+        elif self.add_special_tokens is not None:
+            # Priority 2: plain comma-separated string (no descriptions)
+            self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
+            self._special_token_descriptions = None
+        else:
+            # No special tokens to add
+            self._special_token_descriptions = None
+
+        # Validate the initialization method
+        if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
+            if self._special_token_descriptions is None:
+                logger.warning_rank0(
+                    f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
+                    "Falling back to 'noise_init'."
+                )
+                self.init_special_tokens = "noise_init"
 
 
 @dataclass
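
The dataclass changes above only collect the tokens and their descriptions; the embedding-side initialization lives in the commit's other changed files, which are not shown here. As a rough sketch of how description-based initialization is typically implemented (the function name, noise_std, and its default value are illustrative, not this commit's actual API), assuming the new tokens were already added to the tokenizer and the embedding matrix was resized:

import torch


def init_new_token_embeddings(model, tokenizer, token_descriptions, method="desc_init", noise_std=0.02):
    # token_descriptions: {"<token>": "description text", ...}, as loaded
    # from new_special_tokens_config.
    embeddings = model.get_input_embeddings().weight  # (vocab_size, hidden_size)
    with torch.no_grad():
        for token, description in token_descriptions.items():
            token_id = tokenizer.convert_tokens_to_ids(token)
            if method in ("desc_init", "desc_init_w_noise"):
                # Encode the description with the existing vocabulary and
                # average the embeddings of the resulting tokens.
                desc_ids = tokenizer(description, add_special_tokens=False)["input_ids"]
                vector = embeddings[desc_ids].mean(dim=0)
            else:
                # noise_init: start from the mean of the existing embeddings.
                vector = embeddings.mean(dim=0)
            if method in ("noise_init", "desc_init_w_noise"):
                vector = vector + noise_std * torch.randn_like(vector)
            embeddings[token_id] = vector

Averaging the description's token embeddings places the new token near related concepts in embedding space, which is the point of "semantic" initialization; the optional noise keeps otherwise-identical descriptions from producing identical rows.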