Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-14 23:58:11 +08:00)

Commit 83404c4fa9 (parent 12f852b8d4): support new special token #3420
Former-commit-id: f5c6a47f5193ab3a6c137580992bdcce0b31fdd5

@@ -26,11 +26,11 @@ class DataArguments:
     )
     cutoff_len: int = field(
         default=1024,
-        metadata={"help": "The cutoff length of the model inputs after tokenization."},
+        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
     )
     reserved_label_len: int = field(
         default=1,
-        metadata={"help": "The minimum cutoff length reserved for label after tokenization."},
+        metadata={"help": "The minimum cutoff length reserved for the tokenized labels in the dataset."},
     )
     train_on_prompt: bool = field(
         default=False,

@@ -31,11 +31,11 @@ class GeneratingArguments:
         metadata={"help": "Number of beams for beam search. 1 means no beam search."},
     )
     max_length: int = field(
-        default=512,
+        default=1024,
         metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."},
     )
     max_new_tokens: int = field(
-        default=512,
+        default=1024,
         metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."},
     )
     repetition_penalty: float = field(
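
The two generation defaults above are raised together. As the help text already notes, max_new_tokens overrides max_length when both are set in Hugging Face generation, so the pair of 1024 defaults keeps the limits consistent. A minimal sketch, not from this commit, assuming the standard transformers GenerationConfig API:

from transformers import GenerationConfig

# When both limits are present, generate() honors max_new_tokens; max_length only
# applies if max_new_tokens is left unset.
gen_config = GenerationConfig(max_length=1024, max_new_tokens=1024)
print(gen_config.max_new_tokens)  # 1024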

@@ -33,6 +33,10 @@ class ModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
+    new_special_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Special tokens to be added into the tokenizer."},
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},

@@ -177,6 +181,9 @@ class ModelArguments:
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

+        if self.new_special_tokens is not None:  # support multiple special tokens
+            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
+
         assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
         assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
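
The new new_special_tokens option is a single comma-separated string that __post_init__ splits into a list, mirroring how adapter_name_or_path is handled. A minimal standalone sketch of that parsing; the dataclass name and the token strings are illustrative, not from the repository:

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyModelArguments:
    new_special_tokens: Optional[str] = field(
        default=None,
        metadata={"help": "Special tokens to be added into the tokenizer."},
    )

    def __post_init__(self):
        if self.new_special_tokens is not None:  # support multiple special tokens
            self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]


args = ToyModelArguments(new_special_tokens="<tool>, </tool>")
print(args.new_special_tokens)  # ['<tool>', '</tool>']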

@@ -67,6 +67,9 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
         if finetuning_args.finetuning_type != "lora":
             raise ValueError("Quantization is only compatible with the LoRA method.")

+        if model_args.resize_vocab:
+            raise ValueError("Cannot resize embedding layers of a quantized model.")
+
         if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
             raise ValueError("Cannot create new adapter upon a quantized model.")

@@ -199,10 +202,11 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if (
         training_args.do_train
         and finetuning_args.finetuning_type == "lora"
+        and model_args.quantization_bit is None
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Add token embeddings to `additional_target` to make the added tokens trainable.")
+        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")

     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
         logger.warning("We recommend enable `upcast_layernorm` in quantized training.")

@@ -157,6 +157,17 @@ def init_adapter(
             ):
                 raise ValueError("DoRA is not compatible with PTQ-quantized models.")

+            if model_args.resize_vocab and finetuning_args.additional_target is None:
+                input_embeddings = model.get_input_embeddings()
+                output_embeddings = model.get_output_embeddings()
+                module_names = set()
+                for name, module in model.named_modules():
+                    if module in [input_embeddings, output_embeddings]:
+                        module_names.add(name.split(".")[-1])
+
+                finetuning_args.additional_target = module_names
+                logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+
             peft_kwargs = {
                 "r": finetuning_args.lora_rank,
                 "target_modules": target_modules,
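
When the vocabulary has been resized and no additional_target was given, the block above collects the names of the input and output embedding modules so they are trained alongside the LoRA adapters. A minimal standalone sketch of that lookup; the "gpt2" checkpoint is only an illustrative choice:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
input_embeddings = model.get_input_embeddings()
output_embeddings = model.get_output_embeddings()

# Collect the last path component of every module that is an embedding layer.
module_names = set()
for name, module in model.named_modules():
    if module in [input_embeddings, output_embeddings]:
        module_names.add(name.split(".")[-1])

print(module_names)  # e.g. {'wte', 'lm_head'} for GPT-2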
@ -39,6 +39,8 @@ def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
|
|||||||
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
|
||||||
r"""
|
r"""
|
||||||
Loads pretrained tokenizer.
|
Loads pretrained tokenizer.
|
||||||
|
|
||||||
|
Note: including inplace operation of model_args.
|
||||||
"""
|
"""
|
||||||
init_kwargs = _get_init_kwargs(model_args)
|
init_kwargs = _get_init_kwargs(model_args)
|
||||||
try:
|
try:
|
||||||

@@ -57,6 +59,16 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
         **init_kwargs,
     )

+    if model_args.new_special_tokens is not None:
+        num_added_tokens = tokenizer.add_special_tokens(
+            dict(additional_special_tokens=model_args.new_special_tokens),
+            replace_additional_special_tokens=False,
+        )
+        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        if num_added_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
+
     patch_tokenizer(tokenizer)
     return tokenizer
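
The tokenizer-side change registers the user-supplied tokens as additional special tokens without replacing the ones a chat template may already rely on, and flips resize_vocab so the embedding matrix is grown to match. A minimal sketch of the call; the "gpt2" checkpoint and token strings are illustrative:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
num_added_tokens = tokenizer.add_special_tokens(
    dict(additional_special_tokens=["<tool>", "</tool>"]),
    replace_additional_special_tokens=False,
)
print(num_added_tokens)  # 2
print(len(tokenizer))    # grows by 2, which is why resize_vocab must be enabled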

@@ -42,9 +42,11 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
         current_embedding_size = model.get_input_embeddings().weight.size(0)

     if len(tokenizer) > current_embedding_size:
+        if getattr(model, "quantization_method", None):
+            raise ValueError("Cannot resize embedding layers of a quantized model.")
+
         if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
-            logger.warning("Current model does not support resizing token embeddings.")
-            return
+            raise ValueError("Current model does not support resizing embedding layers.")

         model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
         with context_maybe_zero3:
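
With quantized models now rejected outright, resizing only proceeds when the output embedding is a plain torch.nn.Linear; pad_to_multiple_of=64 rounds the new embedding size up, so it can exceed len(tokenizer). A minimal sketch; the "gpt2" checkpoint and token string are illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.add_special_tokens(dict(additional_special_tokens=["<tool>"]))  # 50257 -> 50258 tokens

if len(tokenizer) > model.get_input_embeddings().weight.size(0):
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)

print(model.get_input_embeddings().weight.size(0))  # 50304, the next multiple of 64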

@@ -30,6 +30,10 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_

         current_max_length = getattr(config, "max_position_embeddings", None)
         if current_max_length and model_args.model_max_length > current_max_length:
+            logger.warning(
+                "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length)
+            )
+            setattr(config, "max_position_embeddings", model_args.model_max_length)
             scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
             logger.warning("Input length is smaller than max length. Consider increase input length.")
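
The RoPE scaling factor is the ceiling of the ratio between the requested model_max_length and the model's original max_position_embeddings, which the new lines now log and write back into the config. A minimal arithmetic sketch; the lengths are illustrative:

import math

current_max_length = 4096   # e.g. config.max_position_embeddings
model_max_length = 8192     # e.g. model_args.model_max_length
scaling_factor = float(math.ceil(model_max_length / current_max_length))
print(scaling_factor)       # 2.0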