From 8b21a60d9cc25f55f3a5905ec5c5f2d6c42fa944 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 6 Mar 2024 15:04:02 +0800
Subject: [PATCH] fix add tokens

Former-commit-id: 9658c63cd94d28bba730a19f73397580b9865d6b
---
 README.md                     | 2 +-
 README_zh.md                  | 2 +-
 src/llmtuner/data/template.py | 9 +++++----
 3 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index b6e49c86..3a74966a 100644
--- a/README.md
+++ b/README.md
@@ -71,7 +71,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 [24/02/28] We supported weight-decomposed LoRA (**[DoRA](https://arxiv.org/abs/2402.09353)**). Try `--use_dora` to activate DoRA training.
 
-[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `tests/llama_pro.py` for usage.
+[24/02/15] We supported **block expansion** proposed by [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro). See `scripts/llama_pro.py` for usage.
 
 [24/02/05] Qwen1.5 (Qwen2 beta version) series models are supported in LLaMA-Factory. Check this [blog post](https://qwenlm.github.io/blog/qwen1.5/) for details.
 
diff --git a/README_zh.md b/README_zh.md
index eaea39eb..f5342726 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -71,7 +71,7 @@ https://github.com/hiyouga/LLaMA-Factory/assets/16256802/ec36a9dd-37f4-4f72-81bd
 
 [24/02/28] 我们支持了 **[DoRA](https://arxiv.org/abs/2402.09353)** 微调。请使用 `--use_dora` 参数进行 DoRA 微调。
 
-[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `tests/llama_pro.py`。
+[24/02/15] 我们支持了 [LLaMA Pro](https://github.com/TencentARC/LLaMA-Pro) 提出的**块扩展**方法。详细用法请参照 `scripts/llama_pro.py`。
 
 [24/02/05] Qwen1.5(Qwen2 测试版)系列模型已在 LLaMA-Factory 中实现微调支持。详情请查阅该[博客页面](https://qwenlm.github.io/zh/blog/qwen1.5/)。
 
diff --git a/src/llmtuner/data/template.py b/src/llmtuner/data/template.py
index 895f2698..e5cbc6a5 100644
--- a/src/llmtuner/data/template.py
+++ b/src/llmtuner/data/template.py
@@ -264,15 +264,14 @@ def _register_template(
 
 def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
     is_added = tokenizer.eos_token_id is None
-    is_oov = eos_token not in tokenizer.get_vocab()
-    tokenizer.add_special_tokens({"eos_token": eos_token})
+    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
 
     if is_added:
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
     else:
         logger.info("Replace eos token: {}".format(tokenizer.eos_token))
 
-    if is_oov:
+    if num_added_tokens > 0:
         logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
 
 
@@ -368,10 +367,12 @@ def get_template_and_fix_tokenizer(
         logger.info("Add pad token: {}".format(tokenizer.pad_token))
 
     if stop_words:
-        tokenizer.add_special_tokens(
+        num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
         logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        if num_added_tokens > 0:
+            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
 
     try:
         tokenizer.chat_template = _get_jinja_template(template, tokenizer)
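
For reviewers, a minimal standalone sketch of the `transformers` behavior the fix depends on (the model name and token string below are arbitrary examples, not part of the patch): `PreTrainedTokenizer.add_special_tokens` returns the number of tokens actually added to the vocabulary, and returns 0 when the token is already present, so its return value is a more reliable resize signal than a separate out-of-vocabulary check.

    from transformers import AutoTokenizer

    # Illustrative model; any Hugging Face tokenizer behaves the same way.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Returns how many tokens were newly added to the vocabulary (0 if "<|im_end|>"
    # already existed), while also setting it as the eos token.
    num_added_tokens = tokenizer.add_special_tokens({"eos_token": "<|im_end|>"})

    if num_added_tokens > 0:
        # The vocabulary grew, so the model's embedding matrix no longer matches it;
        # it must be resized before training, e.g. model.resize_token_embeddings(len(tokenizer)).
        print("New tokens have been added, make sure `resize_vocab` is True.")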