From d8cba9464f6f788c1bf1d89061afd7fb90155e26 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 13 Jan 2025 21:34:20 +0800 Subject: [PATCH] [inference] fix stop token for object detection (#6624) * fix stop token * update minicpm data pipeline * fix npu qlora examples Former-commit-id: e3e2c8c689c54ebb2af264de808502e5a8ba0f2b --- README.md | 23 ++++++++++---- README_zh.md | 31 +++++++++++++------ examples/README.md | 6 ++++ examples/README_zh.md | 6 ++++ ..._npu.yaml => llama3_lora_sft_bnb_npu.yaml} | 2 +- scripts/vllm_infer.py | 9 ++++-- src/llamafactory/chat/hf_engine.py | 2 +- src/llamafactory/chat/vllm_engine.py | 2 +- src/llamafactory/data/collator.py | 10 +++--- src/llamafactory/data/mm_plugin.py | 28 +++++++---------- src/llamafactory/data/template.py | 15 +++++++-- src/llamafactory/train/pt/trainer.py | 2 ++ src/llamafactory/train/sft/trainer.py | 2 ++ src/llamafactory/webui/runner.py | 2 ++ tests/data/test_template.py | 6 ++++ 15 files changed, 101 insertions(+), 45 deletions(-) rename examples/train_qlora/{llama3_lora_sft_otfq_npu.yaml => llama3_lora_sft_bnb_npu.yaml} (88%) diff --git a/README.md b/README.md index 0fdea14c..4327c951 100644 --- a/README.md +++ b/README.md @@ -403,12 +403,16 @@ Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel
For Windows users +#### Install BitsAndBytes + If you want to enable quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of the `bitsandbytes` library, which supports CUDA 11.1 to 12.2. Please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version. ```bash pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl ``` +#### Install FlashAttention-2 + To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
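As a quick way to verify the Windows setup above, the following sketch (an editor's illustration, not part of this patch) checks that a CUDA build of PyTorch is active and that the `bitsandbytes` and `flash-attn` wheels can be imported:

```python
# Minimal sanity check for the Windows QLoRA / FlashAttention-2 prerequisites described above.
import importlib

import torch

print(f"CUDA available: {torch.cuda.is_available()} (CUDA {torch.version.cuda})")

for name in ("bitsandbytes", "flash_attn"):
    try:
        module = importlib.import_module(name)
        print(f"{name} {getattr(module, '__version__', 'unknown')} imported successfully")
    except ImportError as exc:  # wheel missing or built against a different CUDA version
        print(f"{name} is not usable: {exc}")
```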
@@ -444,9 +448,12 @@ If you cannot infer model on NPU devices, try setting `do_sample: false` in the Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) -To use nf4 QLoRA quantization based on bitsandbytes in Ascend NPU, please follow these 3 steps: +#### Install BitsAndBytes + +To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps: + +1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x. -1. Manually compile bnb: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation of bnb. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x. ```bash # Install bitsandbytes from source # Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch @@ -462,15 +469,19 @@ apt-get install -y build-essential cmake # Compile & install cmake -DCOMPUTE_BACKEND=npu -S . make -pip install -e . -``` -2. Install and use the main branch version of transformers. +pip install . ``` + +2. Install transformers from the main branch. + +```bash git clone https://github.com/huggingface/transformers.git cd transformers pip install . ``` -3. Set the double_quantization parameter to false in the training configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_otfq_npu.yaml) for guidance. + +3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml). + ### Data Preparation diff --git a/README_zh.md b/README_zh.md index 72e090f6..bf4eb4d3 100644 --- a/README_zh.md +++ b/README_zh.md @@ -404,19 +404,23 @@ pip install -e ".[torch,metrics]"
Windows 用户指南 +#### 安装 BitsAndBytes + 如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库,支持 CUDA 11.1 到 12.2,请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。 ```bash pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl ``` +#### 安装 FlashAttention-2 + 如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
昇腾 NPU 用户指南 -在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级Python到3.10及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令: +在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令: ```bash # 请替换 URL 为 CANN 版本和设备型号对应的 URL @@ -445,11 +449,15 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh 下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html) -如果要在 Ascend NPU中使用 基于bitsandbytes 的nf4 QLoRA量化,请执行如下3个步骤 -1. 手动编译bnb:请参考 bitsandbytes npu版本的[安装文档](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU)完成bnb的编译安装,编译要求环境cmake版本不低于3.22.1,g++版本不低于12.x -``` -# 从源码安装bitsandbytes -# 克隆bitsandbytes仓库, Ascend NPU目前在multi-backend-refactor中支持 +#### 安装 BitsAndBytes + +如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤: + +1. 手动编译 bitsandbytes:请参考[安装文档](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU)完成 NPU 版的 bitsandbytes 安装,编译要求环境 cmake 版本不低于 3.22.1,g++ 版本不低于 12.x。 + +```bash +# 从源码安装 bitsandbytes +# 克隆 bitsandbytes 仓库,Ascend NPU 目前在 multi-backend-refactor 分支中支持 git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git cd bitsandbytes/ @@ -462,15 +470,18 @@ apt-get install -y build-essential cmake # 编译 & 安装 cmake -DCOMPUTE_BACKEND=npu -S . make -pip install -e . -``` -2. 安装使用transformers的main分支版本 +pip install . ``` + +2. 安装 transformers 的 main 分支版本。 + +```bash git clone https://github.com/huggingface/transformers.git cd transformers pip install . ``` -3. 设置训练参数中的double_quantization参数为false,可参考[示例](examples/train_qlora/llama3_lora_sft_otfq_npu.yaml) + +3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。
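For reference, the NPU QLoRA settings above (4-bit bitsandbytes quantization with `double_quantization: false`) correspond roughly to the standalone Hugging Face call below. This is a hedged sketch of the equivalent `BitsAndBytesConfig`, not the exact code path LLaMA-Factory takes; the compute dtype and `device_map` choice are illustrative assumptions:

```python
# Hedged sketch: what the NPU QLoRA example config roughly maps to in plain transformers.
import torch
import torch_npu  # noqa: F401  # importing torch_npu registers the Ascend NPU device
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantization_bit: 4
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # assumption: bf16 compute
    bnb_4bit_use_double_quant=False,        # double_quantization: false (required on Ascend NPU)
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # model used in the example YAML
    quantization_config=bnb_config,
    device_map="auto",
)
```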
diff --git a/examples/README.md b/examples/README.md index 89f7d174..e589b980 100644 --- a/examples/README.md +++ b/examples/README.md @@ -109,6 +109,12 @@ USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml ``` +#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml +``` + #### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization ```bash diff --git a/examples/README_zh.md b/examples/README_zh.md index 2c108e56..b75a6239 100644 --- a/examples/README_zh.md +++ b/examples/README_zh.md @@ -109,6 +109,12 @@ USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml ``` +#### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调 + +```bash +llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml +``` + #### 基于 4/8 比特 GPTQ 量化进行指令监督微调 ```bash diff --git a/examples/train_qlora/llama3_lora_sft_otfq_npu.yaml b/examples/train_qlora/llama3_lora_sft_bnb_npu.yaml similarity index 88% rename from examples/train_qlora/llama3_lora_sft_otfq_npu.yaml rename to examples/train_qlora/llama3_lora_sft_bnb_npu.yaml index 983acd39..babdc47c 100644 --- a/examples/train_qlora/llama3_lora_sft_otfq_npu.yaml +++ b/examples/train_qlora/llama3_lora_sft_bnb_npu.yaml @@ -1,7 +1,7 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct quantization_bit: 4 -quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)] +quantization_method: bitsandbytes double_quantization: false trust_remote_code: true diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py index c9a8cfb6..907fb831 100644 --- a/scripts/vllm_infer.py +++ b/scripts/vllm_infer.py @@ -50,11 +50,15 @@ def vllm_infer( top_k: int = 50, max_new_tokens: int = 1024, repetition_penalty: float = 1.0, + pipeline_parallel_size: int = 1, ): r""" Performs batch generation using vLLM engine, which supports tensor parallelism. 
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo """ + if pipeline_parallel_size > get_device_count(): + raise ValueError("Pipeline parallel size should be smaller than the number of gpus.") + model_args, data_args, _, generating_args = get_infer_args( dict( model_name_or_path=model_name_or_path, @@ -107,7 +111,7 @@ def vllm_infer( temperature=generating_args.temperature, top_p=generating_args.top_p or 1.0, # top_p must > 0 top_k=generating_args.top_k, - stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, + stop_token_ids=template_obj.get_stop_token_ids(tokenizer), max_tokens=generating_args.max_new_tokens, skip_special_tokens=False, ) @@ -120,7 +124,8 @@ def vllm_infer( "model": model_args.model_name_or_path, "trust_remote_code": True, "dtype": model_args.infer_dtype, - "tensor_parallel_size": get_device_count() or 1, + "tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1, + "pipeline_parallel_size": pipeline_parallel_size, "disable_log_stats": True, "enable_lora": model_args.adapter_name_or_path is not None, } diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index f63b6434..879c407a 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -133,7 +133,7 @@ class HuggingfaceEngine(BaseEngine): if repetition_penalty is not None else generating_args["repetition_penalty"], length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"], - eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids, + eos_token_id=template.get_stop_token_ids(tokenizer), pad_token_id=tokenizer.pad_token_id, ) ) diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index fd54b5a9..ee9c4c8c 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -168,7 +168,7 @@ class VllmEngine(BaseEngine): top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 top_k=top_k if top_k is not None else self.generating_args["top_k"], stop=stop, - stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, + stop_token_ids=self.template.get_stop_token_ids(self.tokenizer), max_tokens=max_tokens, skip_special_tokens=self.generating_args["skip_special_tokens"], ) diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index dfd853ca..01742bdc 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence import torch import torch.nn.functional as F -from torch.nn.utils.rnn import pad_sequence from transformers import DataCollatorForSeq2Seq from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER @@ -154,11 +153,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features = features.data # use default_collate() instead of BatchEncoding.to() if "image_bound" in features: # for minicpmv inputs - features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]] - features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0) - new_features = {"data": features} - new_features.update({"labels": features["labels"]}) - features = new_features + features["position_ids"] = ( + 
torch.arange(features["input_ids"].size(1)).long().unsqueeze(0).expand_as(features["input_ids"]) + ) + return {"data": features, "labels": features["labels"]} return features diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 435496bb..7cae556a 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -269,9 +269,10 @@ class CpmVPlugin(BasePlugin): messages = deepcopy(messages) image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") mm_inputs = {} + if len(images) != 0 and len(videos) != 0: + raise ValueError("MiniCPM-V model does not support input images and videos at the same time.") if len(videos) != 0: - assert len(images) == 0, "Only support video and image sft seperately" max_slice_nums = 2 use_image_id = False mm_inputs = self._get_mm_inputs([], videos, processor) @@ -286,10 +287,9 @@ class CpmVPlugin(BasePlugin): content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) while VIDEO_PLACEHOLDER in content: + video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1 + content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1) num_video_tokens += 1 - content = content.replace( - VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1 - ) message["content"] = content.replace("{{image}}", "(./)") @@ -310,10 +310,7 @@ class CpmVPlugin(BasePlugin): final_text + text_chunks[i] + image_processor.get_slice_image_placeholder( - image_sizes[0][i], - i, - max_slice_nums, - use_image_id, + image_sizes[0][i], i, max_slice_nums, use_image_id ) ) final_text += text_chunks[-1] @@ -338,7 +335,6 @@ class CpmVPlugin(BasePlugin): image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") mm_inputs = {} - if len(images) != 0: images = self._regularize_images( images, @@ -351,6 +347,7 @@ class CpmVPlugin(BasePlugin): for valid_image_nums in valid_image_nums_ls: new_images.append(images[idx : idx + valid_image_nums]) idx += valid_image_nums + images = new_images image_inputs = image_processor( @@ -383,7 +380,6 @@ class CpmVPlugin(BasePlugin): self._validate_input(images, videos) image_bounds_list = [] valid_image_nums_ls = [] - for input_ids in batch_ids: input_ids_ = torch.tensor(input_ids) start_cond = (input_ids_ == processor.tokenizer.im_start_id) | ( @@ -424,8 +420,8 @@ class LlavaPlugin(BasePlugin): for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 message["content"] = content.replace("{{image}}", self.image_token) @@ -478,8 +474,8 @@ class LlavaNextPlugin(BasePlugin): else: image_seqlen = 1 - num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 message["content"] = content.replace("{{image}}", self.image_token) @@ -529,8 +525,8 @@ class LlavaNextVideoPlugin(BasePlugin): else: image_seqlen = 1 - num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 message["content"] = content.replace("{{image}}", self.image_token) @@ -586,8 +582,8 @@ class PaliGemmaPlugin(BasePlugin): for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + num_image_tokens += 1 message["content"] = 
content.replace("{{image}}", "") @@ -840,12 +836,12 @@ class VideoLlavaPlugin(BasePlugin): for message in messages: content = message["content"] while IMAGE_PLACEHOLDER in content: - num_image_tokens += 1 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) + num_image_tokens += 1 while VIDEO_PLACEHOLDER in content: - num_video_tokens += 1 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1) + num_video_tokens += 1 content = content.replace("{{image}}", self.image_token) message["content"] = content.replace("{{video}}", self.video_token) diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 60e34a8d..05eb0cda 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -89,6 +89,16 @@ class Template: """ return self.format_tools.extract(content) + def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]: + r""" + Returns stop token ids. + """ + stop_token_ids = {tokenizer.eos_token_id} + for token in self.stop_words: + stop_token_ids.add(tokenizer.convert_tokens_to_ids(token)) + + return list(stop_token_ids) + def _encode( self, tokenizer: "PreTrainedTokenizer", @@ -205,7 +215,7 @@ def _register_template( format_tools: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None, default_system: str = "", - stop_words: Sequence[str] = [], + stop_words: Optional[Sequence[str]] = None, efficient_eos: bool = False, replace_eos: bool = False, replace_jinja_template: bool = False, @@ -248,7 +258,7 @@ def _register_template( format_tools=format_tools or default_tool_formatter, format_prefix=format_prefix or default_prefix_formatter, default_system=default_system, - stop_words=stop_words, + stop_words=stop_words or [], efficient_eos=efficient_eos, replace_eos=replace_eos, replace_jinja_template=replace_jinja_template, @@ -566,6 +576,7 @@ _register_template( ) +# copied from chatml template _register_template( name="cpm_v", format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 11d91111..5547a937 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -79,6 +79,8 @@ class CustomTrainer(Trainer): ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details. + + It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged. """ loss = super().compute_loss(model, inputs, return_outputs, **kwargs) if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 28ec25eb..95542eee 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -94,6 +94,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: r""" Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details. + + It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged. 
""" loss = super().compute_loss(model, inputs, return_outputs, **kwargs) if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index dc91ad50..fe1e643c 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -19,6 +19,7 @@ from subprocess import Popen, TimeoutExpired from typing import TYPE_CHECKING, Any, Dict, Generator, Optional from transformers.trainer import TRAINING_ARGS_NAME +from transformers.utils import is_torch_npu_available from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray @@ -172,6 +173,7 @@ class Runner: if get("top.quantization_bit") in QUANTIZATION_BITS: args["quantization_bit"] = int(get("top.quantization_bit")) args["quantization_method"] = get("top.quantization_method") + args["double_quantization"] = not is_torch_npu_available() # freeze config if args["finetuning_type"] == "freeze": diff --git a/tests/data/test_template.py b/tests/data/test_template.py index 3cf9227e..dead0af0 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -120,6 +120,12 @@ def test_jinja_template(use_fast: bool): assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES) +def test_get_stop_token_ids(): + tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA) + template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3")) + assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009} + + @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.") @pytest.mark.parametrize("use_fast", [True, False]) def test_gemma_template(use_fast: bool):