mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 10:10:35 +08:00
[model] adds semantic initialization support for special tokens (#9267)
Co-authored-by: ximingxing <ximingxing@tencent.com>
This commit is contained in:
@@ -22,9 +22,12 @@ from typing import Any, Literal, Optional, Union
|
||||
import torch
|
||||
from transformers.training_args import _convert_str_dict
|
||||
from typing_extensions import Self
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
|
||||
from ..extras.logging import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@dataclass
|
||||
class BaseModelArguments:
|
||||
@@ -75,6 +78,28 @@ class BaseModelArguments:
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
new_special_tokens_config: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Path to YAML config with special token descriptions for semantic initialization. "
|
||||
"If set, this takes precedence over add_special_tokens. "
|
||||
"YAML format: {'<token>': 'description text', ...}"
|
||||
)
|
||||
},
|
||||
)
|
||||
init_special_tokens: Literal["noise_init", "desc_init", "desc_init_w_noise"] = field(
|
||||
default="noise_init",
|
||||
metadata={
|
||||
"help": (
|
||||
"Initialization method for new special tokens: "
|
||||
"'noise_init' (default, random noise around mean), "
|
||||
"'desc_init' (semantic initialization from descriptions), "
|
||||
"'desc_init_w_noise' (semantic + random noise). "
|
||||
"Note: 'desc_init' methods require new_special_tokens_config."
|
||||
)
|
||||
},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
@@ -185,8 +210,63 @@ class BaseModelArguments:
|
||||
if self.add_tokens is not None: # support multiple tokens
|
||||
self.add_tokens = [token.strip() for token in self.add_tokens.split(",")]
|
||||
|
||||
if self.add_special_tokens is not None: # support multiple special tokens
|
||||
# Process special tokens with priority: new_special_tokens_config > add_special_tokens
|
||||
if self.new_special_tokens_config is not None:
|
||||
# Priority 1: Load from YAML config (extracts both tokens and descriptions)
|
||||
try:
|
||||
cfg = OmegaConf.load(self.new_special_tokens_config)
|
||||
token_descriptions = OmegaConf.to_container(cfg)
|
||||
|
||||
if not isinstance(token_descriptions, dict):
|
||||
raise ValueError(
|
||||
f"YAML config must be a dictionary mapping tokens to descriptions. "
|
||||
f"Got: {type(token_descriptions)}"
|
||||
)
|
||||
|
||||
# Extract token list from config keys
|
||||
extracted_tokens = list(token_descriptions.keys())
|
||||
|
||||
# Warn if both are set
|
||||
if self.add_special_tokens is not None:
|
||||
logger.warning_rank0(
|
||||
"Both 'new_special_tokens_config' and 'add_special_tokens' are set. "
|
||||
f"Using tokens from config: {extracted_tokens}"
|
||||
)
|
||||
|
||||
# Override add_special_tokens with extracted tokens (as list)
|
||||
self.add_special_tokens = extracted_tokens
|
||||
|
||||
# Store descriptions internally for later use (internal attribute)
|
||||
self._special_token_descriptions = token_descriptions
|
||||
|
||||
logger.info_rank0(
|
||||
f"Loaded {len(extracted_tokens)} special tokens with descriptions from: "
|
||||
f"{self.new_special_tokens_config}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error_rank0(
|
||||
f"Failed to load special tokens config from '{self.new_special_tokens_config}': {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
elif self.add_special_tokens is not None:
|
||||
# Priority 2: Use simple comma-separated string (no descriptions)
|
||||
self.add_special_tokens = [token.strip() for token in self.add_special_tokens.split(",")]
|
||||
self._special_token_descriptions = None
|
||||
|
||||
else:
|
||||
# No special tokens to add
|
||||
self._special_token_descriptions = None
|
||||
|
||||
# Validate init method
|
||||
if self.init_special_tokens in ["desc_init", "desc_init_w_noise"]:
|
||||
if self._special_token_descriptions is None:
|
||||
logger.warning_rank0(
|
||||
f"init_special_tokens='{self.init_special_tokens}' requires new_special_tokens_config. "
|
||||
"Falling back to 'noise_init'"
|
||||
)
|
||||
self.init_special_tokens = "noise_init"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user