diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index f86291b0..8fcfcb42 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -69,6 +69,10 @@ class BaseModelArguments:
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
+    new_normal_tokens: Optional[str] = field(
+        default=None,
+        metadata={"help": "Normal tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
+    )
     model_revision: str = field(
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
@@ -176,6 +180,9 @@ class BaseModelArguments:
         if self.adapter_name_or_path is not None:  # support merging multiple lora weights
             self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
 
+        if self.new_normal_tokens is not None:  # support multiple normal tokens
+            self.new_normal_tokens = [token.strip() for token in self.new_normal_tokens.split(",")]
+
         if self.new_special_tokens is not None:  # support multiple special tokens
             self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")]
 
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 28014de9..52b25110 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -55,14 +55,24 @@ def patch_tokenizer(tokenizer: "PreTrainedTokenizer", model_args: "ModelArgument
         tokenizer.model_max_length = model_args.model_max_length  # enlarge the tokenizer max length
 
     if model_args.new_special_tokens is not None:
-        num_added_tokens = tokenizer.add_special_tokens(
+        num_added_special_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
-        if num_added_tokens > 0 and not model_args.resize_vocab:
+        logger.info_rank0("Add special tokens {} to vocab.".format(",".join(model_args.new_special_tokens)))
+        if num_added_special_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning_rank0("New special tokens have been added, changed `resize_vocab` to True.")
+
+    if model_args.new_normal_tokens is not None:
+        num_added_normal_tokens = tokenizer.add_tokens(
+            new_tokens=model_args.new_normal_tokens,
+            special_tokens=False,
+        )
+        logger.info_rank0("Add normal tokens {} to vocab.".format(",".join(model_args.new_normal_tokens)))
+        if num_added_normal_tokens > 0 and not model_args.resize_vocab:
+            model_args.resize_vocab = True
+            logger.warning_rank0("New normal tokens have been added, changed `resize_vocab` to True.")
 
 
 def patch_processor(
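
The following standalone sketch is not part of the patch; it only illustrates the difference between the two code paths above using plain Hugging Face transformers APIs, with a placeholder model name and placeholder token strings. `add_special_tokens` registers entries as additional special tokens, while `add_tokens(..., special_tokens=False)` adds ordinary vocabulary entries; in both cases the embedding matrix has to grow afterwards, which is what forcing `resize_vocab` to True is meant to guarantee downstream.

# Illustration only (not part of the patch). Model name and token strings are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# new_special_tokens path: registered as additional special tokens, so they are
# never split by the tokenizer and can be skipped during decoding.
num_special = tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<tool_call>", "</tool_call>"]},
    replace_additional_special_tokens=False,
)

# new_normal_tokens path: added as ordinary vocabulary entries.
num_normal = tokenizer.add_tokens(["llamafactory"], special_tokens=False)

# Either way, the new token ids lie outside the original embedding matrix,
# so it must be resized before training or inference.
if num_special + num_normal > 0:
    model.resize_token_embeddings(len(tokenizer))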