Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 06:12:50 +08:00)
[misc] fix new tokens adding (#7253)
Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
parent: b8cddbc7d7
commit: 1302ca39f6
@@ -69,6 +69,10 @@ class BaseModelArguments:
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
+    new_normal_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Normal tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
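For context, here is a minimal, self-contained sketch of how such field() declarations surface as command-line flags through HfArgumentParser; TokenArgs is a hypothetical stand-in for the real BaseModelArguments class, and the token strings are made-up examples.

# Hypothetical stand-in for BaseModelArguments, showing only the two
# token-related fields this commit touches.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class TokenArgs:
    new_special_tokens: Optional[str] = field(
        default=None,
        metadata={"help": "Special tokens to add. Use commas to separate multiple tokens."},
    )
    new_normal_tokens: Optional[str] = field(
        default=None,
        metadata={"help": "Normal tokens to add. Use commas to separate multiple tokens."},
    )


parser = HfArgumentParser(TokenArgs)
(args,) = parser.parse_args_into_dataclasses(
    ["--new_special_tokens", "<tool_call>", "--new_normal_tokens", "foo,bar"]
)
print(args.new_normal_tokens)  # "foo,bar" -- split into a list later, in __post_init__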
@@ -176,6 +180,9 @@ class BaseModelArguments:
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]

+        if self.new_normal_tokens is not None:  # support multiple normal tokens
+            self.new_normal_tokens = [token.strip() for token in self.new_normal_tokens.split(",")]
+
         if self.new_special_tokens is not None:  # support multiple special tokens
             self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]

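The splitting itself is plain string handling; a quick standalone check of the behavior (the token strings here are made up for illustration):

# Standalone check of the comma-splitting used in __post_init__ above;
# surrounding whitespace is stripped, so spacing around commas is forgiving.
raw = "<think>, </think> ,<tool_call>"
tokens = [token.strip() for token in raw.split(",")]
print(tokens)  # ['<think>', '</think>', '<tool_call>']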
@@ -55,14 +55,24 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments"
         tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length

     if model_args.new_special_tokens is not None:
-        num_added_tokens = tokenizer.add_special_tokens(
+        num_added_special_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
-        if num_added_tokens > 0 and not model_args.resize_vocab:
+        logger.info_rank0("Add special tokens {} to vocab.".format(",".join(model_args.new_special_tokens)))
+        if num_added_special_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning_rank0("New special tokens have been added, changed `resize_vocab` to True.")
+
+    if model_args.new_normal_tokens is not None:
+        num_added_normal_tokens = tokenizer.add_tokens(
+            new_tokens=model_args.new_normal_tokens,
+            special_tokens=False,
+        )
+        logger.info_rank0("Add normal tokens {} to vocab.".format(",".join(model_args.new_normal_tokens)))
+        if num_added_normal_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning_rank0("New normal tokens have been added, changed `resize_vocab` to True.")


 def patch_processor(
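The distinction the patch draws, add_special_tokens versus add_tokens(..., special_tokens=False), maps onto the public Hugging Face transformers tokenizer API. Below is a hedged sketch of both paths plus the embedding resize that flipping resize_vocab to True ultimately triggers; the model name and token strings are illustrative examples, not taken from the commit.

# Sketch of the two tokenizer-extension paths, using the public
# Hugging Face transformers API (model and tokens are examples).
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Special tokens: never split during tokenization and skipped by
# decode(skip_special_tokens=True); appended rather than replacing
# the existing additional_special_tokens.
num_special = tokenizer.add_special_tokens(
    dict(additional_special_tokens=["<tool_call>"]),
    replace_additional_special_tokens=False,
)

# Normal tokens: ordinary vocab entries that behave like any other subword.
num_normal = tokenizer.add_tokens(["llamafactory"], special_tokens=False)

# Either path can grow the vocab, so the embedding matrix must be resized
# to match -- this is what setting resize_vocab=True arranges downstream.
if num_special + num_normal > 0:
    model.resize_token_embeddings(len(tokenizer))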