mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[inference] fix stop token for object detection (#6624)
* fix stop token * update minicpm data pipeline * fix npu qlora examples Former-commit-id: e3e2c8c689c54ebb2af264de808502e5a8ba0f2b
This commit is contained in:
parent
089c7d5e51
commit
d8cba9464f
23
README.md
23
README.md
@ -403,12 +403,16 @@ Extra dependencies available: torch, torch-npu, metrics, deepspeed, liger-kernel
|
|||||||
|
|
||||||
<details><summary>For Windows users</summary>
|
<details><summary>For Windows users</summary>
|
||||||
|
|
||||||
|
#### Install BitsAndBytes
|
||||||
|
|
||||||
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you need to install a pre-built version of `bitsandbytes` library, which supports CUDA 11.1 to 12.2, please select the appropriate [release version](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels) based on your CUDA version.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Install Flash Attention-2
|
||||||
|
|
||||||
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
|
To enable FlashAttention-2 on the Windows platform, you need to install the precompiled `flash-attn` library, which supports CUDA 12.1 to 12.2. Please download the corresponding version from [flash-attention](https://github.com/bdashore3/flash-attention/releases) based on your requirements.
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
@ -444,9 +448,12 @@ If you cannot infer model on NPU devices, try setting `do_sample: false` in the
|
|||||||
|
|
||||||
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
Download the pre-built Docker images: [32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||||
|
|
||||||
To use nf4 QLoRA quantization based on bitsandbytes in Ascend NPU, please follow these 3 steps:
|
#### Install BitsAndBytes
|
||||||
|
|
||||||
|
To use QLoRA based on bitsandbytes on Ascend NPU, please follow these 3 steps:
|
||||||
|
|
||||||
|
1. Manually compile bitsandbytes: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x.
|
||||||
|
|
||||||
1. Manually compile bnb: Refer to [the installation documentation](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU) for the NPU version of bitsandbytes to complete the compilation and installation of bnb. The compilation requires a cmake version of at least 3.22.1 and a g++ version of at least 12.x.
|
|
||||||
```bash
|
```bash
|
||||||
# Install bitsandbytes from source
|
# Install bitsandbytes from source
|
||||||
# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch
|
# Clone bitsandbytes repo, Ascend NPU backend is currently enabled on multi-backend-refactor branch
|
||||||
@ -462,15 +469,19 @@ apt-get install -y build-essential cmake
|
|||||||
# Compile & install
|
# Compile & install
|
||||||
cmake -DCOMPUTE_BACKEND=npu -S .
|
cmake -DCOMPUTE_BACKEND=npu -S .
|
||||||
make
|
make
|
||||||
pip install -e .
|
pip install .
|
||||||
```
|
|
||||||
2. Install and use the main branch version of transformers.
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
2. Install transformers from the main branch.
|
||||||
|
|
||||||
|
```bash
|
||||||
git clone -b https://github.com/huggingface/transformers.git
|
git clone -b https://github.com/huggingface/transformers.git
|
||||||
cd transformers
|
cd transformers
|
||||||
pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
3. Set the double_quantization parameter to false in the training configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_otfq_npu.yaml) for guidance.
|
|
||||||
|
3. Set `double_quantization: false` in the configuration. You can refer to the [example](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml).
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
### Data Preparation
|
### Data Preparation
|
||||||
|
31
README_zh.md
31
README_zh.md
@ -404,19 +404,23 @@ pip install -e ".[torch,metrics]"
|
|||||||
|
|
||||||
<details><summary>Windows 用户指南</summary>
|
<details><summary>Windows 用户指南</summary>
|
||||||
|
|
||||||
|
#### 安装 BitsAndBytes
|
||||||
|
|
||||||
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
|
如果要在 Windows 平台上开启量化 LoRA(QLoRA),需要安装预编译的 `bitsandbytes` 库, 支持 CUDA 11.1 到 12.2, 请根据您的 CUDA 版本情况选择适合的[发布版本](https://github.com/jllllll/bitsandbytes-windows-webui/releases/tag/wheels)。
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 安装 Flash Attention-2
|
||||||
|
|
||||||
如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
|
如果要在 Windows 平台上开启 FlashAttention-2,需要安装预编译的 `flash-attn` 库,支持 CUDA 12.1 到 12.2,请根据需求到 [flash-attention](https://github.com/bdashore3/flash-attention/releases) 下载对应版本安装。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details><summary>昇腾 NPU 用户指南</summary>
|
<details><summary>昇腾 NPU 用户指南</summary>
|
||||||
|
|
||||||
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级Python到3.10及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
在昇腾 NPU 设备上安装 LLaMA Factory 时,请升级 Python 到 3.10 及以上,并需要指定额外依赖项,使用 `pip install -e ".[torch-npu,metrics]"` 命令安装。此外,还需要安装 **[Ascend CANN Toolkit 与 Kernels](https://www.hiascend.com/developer/download/community/result?module=cann)**,安装方法请参考[安装教程](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC2alpha002/quickstart/quickstart/quickstart_18_0004.html)或使用以下命令:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
# 请替换 URL 为 CANN 版本和设备型号对应的 URL
|
||||||
@ -445,11 +449,15 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
|||||||
|
|
||||||
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
下载预构建 Docker 镜像:[32GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/130.html) | [64GB](http://mirrors.cn-central-221.ovaijisuan.com/detail/131.html)
|
||||||
|
|
||||||
如果要在 Ascend NPU中使用 基于bitsandbytes 的nf4 QLoRA量化,请执行如下3个步骤
|
#### 安装 BitsAndBytes
|
||||||
1. 手动编译bnb:请参考 bitsandbytes npu版本的[安装文档](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU)完成bnb的编译安装,编译要求环境cmake版本不低于3.22.1,g++版本不低于12.x
|
|
||||||
```
|
如果要在 Ascend NPU 上进行基于 bitsandbytes 的 QLoRA 量化微调,请执行如下步骤:
|
||||||
# 从源码安装bitsandbytes
|
|
||||||
# 克隆bitsandbytes仓库, Ascend NPU目前在multi-backend-refactor中支持
|
1. 手动编译 bitsandbytes:请参考[安装文档](https://huggingface.co/docs/bitsandbytes/installation?backend=Ascend+NPU&platform=Ascend+NPU)完成 NPU 版的 bitsandbytes 安装,编译要求环境 cmake 版本不低于 3.22.1,g++ 版本不低于 12.x。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 从源码安装 bitsandbytes
|
||||||
|
# 克隆 bitsandbytes 仓库, Ascend NPU 目前在 multi-backend-refactor 中支持
|
||||||
git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
|
git clone -b multi-backend-refactor https://github.com/bitsandbytes-foundation/bitsandbytes.git
|
||||||
cd bitsandbytes/
|
cd bitsandbytes/
|
||||||
|
|
||||||
@ -462,15 +470,18 @@ apt-get install -y build-essential cmake
|
|||||||
# 编译 & 安装
|
# 编译 & 安装
|
||||||
cmake -DCOMPUTE_BACKEND=npu -S .
|
cmake -DCOMPUTE_BACKEND=npu -S .
|
||||||
make
|
make
|
||||||
pip install -e .
|
pip install .
|
||||||
```
|
|
||||||
2. 安装使用transformers的main分支版本
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
2. 安装 transformers 的 main 分支版本。
|
||||||
|
|
||||||
|
```bash
|
||||||
git clone -b https://github.com/huggingface/transformers.git
|
git clone -b https://github.com/huggingface/transformers.git
|
||||||
cd transformers
|
cd transformers
|
||||||
pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
3. 设置训练参数中的double_quantization参数为false,可参考[示例](examples/train_qlora/llama3_lora_sft_otfq_npu.yaml)
|
|
||||||
|
3. 在训练参数中设置 `double_quantization: false`,可参考[示例](examples/train_qlora/llama3_lora_sft_bnb_npu.yaml)。
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
@ -109,6 +109,12 @@ USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
|
|||||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Supervised Fine-Tuning with 4-bit Bitsandbytes Quantization on Ascend NPU
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
#### Supervised Fine-Tuning with 4/8-bit GPTQ Quantization
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -109,6 +109,12 @@ USE_RAY=1 llamafactory-cli train examples/train_full/llama3_lora_sft_ray.yaml
|
|||||||
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_otfq.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 在 NPU 上基于 4 比特 Bitsandbytes 量化进行指令监督微调
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llamafactory-cli train examples/train_qlora/llama3_lora_sft_bnb_npu.yaml
|
||||||
|
```
|
||||||
|
|
||||||
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
#### 基于 4/8 比特 GPTQ 量化进行指令监督微调
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
### model
|
### model
|
||||||
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
|
||||||
quantization_bit: 4
|
quantization_bit: 4
|
||||||
quantization_method: bitsandbytes # choices: [bitsandbytes (4/8), hqq (2/3/4/5/6/8), eetq (8)]
|
quantization_method: bitsandbytes
|
||||||
double_quantization: false
|
double_quantization: false
|
||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
|
|
@ -50,11 +50,15 @@ def vllm_infer(
|
|||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
max_new_tokens: int = 1024,
|
max_new_tokens: int = 1024,
|
||||||
repetition_penalty: float = 1.0,
|
repetition_penalty: float = 1.0,
|
||||||
|
pipeline_parallel_size: int = 1,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Performs batch generation using vLLM engine, which supports tensor parallelism.
|
Performs batch generation using vLLM engine, which supports tensor parallelism.
|
||||||
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
|
Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
|
||||||
"""
|
"""
|
||||||
|
if pipeline_parallel_size > get_device_count():
|
||||||
|
raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
|
||||||
|
|
||||||
model_args, data_args, _, generating_args = get_infer_args(
|
model_args, data_args, _, generating_args = get_infer_args(
|
||||||
dict(
|
dict(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
@ -107,7 +111,7 @@ def vllm_infer(
|
|||||||
temperature=generating_args.temperature,
|
temperature=generating_args.temperature,
|
||||||
top_p=generating_args.top_p or 1.0, # top_p must > 0
|
top_p=generating_args.top_p or 1.0, # top_p must > 0
|
||||||
top_k=generating_args.top_k,
|
top_k=generating_args.top_k,
|
||||||
stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
|
||||||
max_tokens=generating_args.max_new_tokens,
|
max_tokens=generating_args.max_new_tokens,
|
||||||
skip_special_tokens=False,
|
skip_special_tokens=False,
|
||||||
)
|
)
|
||||||
@ -120,7 +124,8 @@ def vllm_infer(
|
|||||||
"model": model_args.model_name_or_path,
|
"model": model_args.model_name_or_path,
|
||||||
"trust_remote_code": True,
|
"trust_remote_code": True,
|
||||||
"dtype": model_args.infer_dtype,
|
"dtype": model_args.infer_dtype,
|
||||||
"tensor_parallel_size": get_device_count() or 1,
|
"tensor_parallel_size": (get_device_count() // pipeline_parallel_size) or 1,
|
||||||
|
"pipeline_parallel_size": pipeline_parallel_size,
|
||||||
"disable_log_stats": True,
|
"disable_log_stats": True,
|
||||||
"enable_lora": model_args.adapter_name_or_path is not None,
|
"enable_lora": model_args.adapter_name_or_path is not None,
|
||||||
}
|
}
|
||||||
|
@ -133,7 +133,7 @@ class HuggingfaceEngine(BaseEngine):
|
|||||||
if repetition_penalty is not None
|
if repetition_penalty is not None
|
||||||
else generating_args["repetition_penalty"],
|
else generating_args["repetition_penalty"],
|
||||||
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
|
||||||
eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
|
eos_token_id=template.get_stop_token_ids(tokenizer),
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -168,7 +168,7 @@ class VllmEngine(BaseEngine):
|
|||||||
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0
|
||||||
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
top_k=top_k if top_k is not None else self.generating_args["top_k"],
|
||||||
stop=stop,
|
stop=stop,
|
||||||
stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
skip_special_tokens=self.generating_args["skip_special_tokens"],
|
skip_special_tokens=self.generating_args["skip_special_tokens"],
|
||||||
)
|
)
|
||||||
|
@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
|
||||||
from transformers import DataCollatorForSeq2Seq
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
|
||||||
@ -154,11 +153,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
features = features.data # use default_collate() instead of BatchEncoding.to()
|
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||||
|
|
||||||
if "image_bound" in features: # for minicpmv inputs
|
if "image_bound" in features: # for minicpmv inputs
|
||||||
features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
|
features["position_ids"] = (
|
||||||
features["position_ids"] = pad_sequence(features["position_ids"], batch_first=True, padding_value=0)
|
torch.arange(features["input_ids"].size(1)).long().unsqueeze(0).expand_as(features["input_ids"])
|
||||||
new_features = {"data": features}
|
)
|
||||||
new_features.update({"labels": features["labels"]})
|
return {"data": features, "labels": features["labels"]}
|
||||||
features = new_features
|
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
@ -269,9 +269,10 @@ class CpmVPlugin(BasePlugin):
|
|||||||
messages = deepcopy(messages)
|
messages = deepcopy(messages)
|
||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
|
if len(images) != 0 and len(videos) != 0:
|
||||||
|
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
||||||
|
|
||||||
if len(videos) != 0:
|
if len(videos) != 0:
|
||||||
assert len(images) == 0, "Only support video and image sft seperately"
|
|
||||||
max_slice_nums = 2
|
max_slice_nums = 2
|
||||||
use_image_id = False
|
use_image_id = False
|
||||||
mm_inputs = self._get_mm_inputs([], videos, processor)
|
mm_inputs = self._get_mm_inputs([], videos, processor)
|
||||||
@ -286,10 +287,9 @@ class CpmVPlugin(BasePlugin):
|
|||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||||
|
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||||
num_video_tokens += 1
|
num_video_tokens += 1
|
||||||
content = content.replace(
|
|
||||||
VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
|
|
||||||
)
|
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||||
|
|
||||||
@ -310,10 +310,7 @@ class CpmVPlugin(BasePlugin):
|
|||||||
final_text
|
final_text
|
||||||
+ text_chunks[i]
|
+ text_chunks[i]
|
||||||
+ image_processor.get_slice_image_placeholder(
|
+ image_processor.get_slice_image_placeholder(
|
||||||
image_sizes[0][i],
|
image_sizes[0][i], i, max_slice_nums, use_image_id
|
||||||
i,
|
|
||||||
max_slice_nums,
|
|
||||||
use_image_id,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
final_text += text_chunks[-1]
|
final_text += text_chunks[-1]
|
||||||
@ -338,7 +335,6 @@ class CpmVPlugin(BasePlugin):
|
|||||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
|
||||||
mm_inputs = {}
|
mm_inputs = {}
|
||||||
|
|
||||||
if len(images) != 0:
|
if len(images) != 0:
|
||||||
images = self._regularize_images(
|
images = self._regularize_images(
|
||||||
images,
|
images,
|
||||||
@ -351,6 +347,7 @@ class CpmVPlugin(BasePlugin):
|
|||||||
for valid_image_nums in valid_image_nums_ls:
|
for valid_image_nums in valid_image_nums_ls:
|
||||||
new_images.append(images[idx : idx + valid_image_nums])
|
new_images.append(images[idx : idx + valid_image_nums])
|
||||||
idx += valid_image_nums
|
idx += valid_image_nums
|
||||||
|
|
||||||
images = new_images
|
images = new_images
|
||||||
|
|
||||||
image_inputs = image_processor(
|
image_inputs = image_processor(
|
||||||
@ -383,7 +380,6 @@ class CpmVPlugin(BasePlugin):
|
|||||||
self._validate_input(images, videos)
|
self._validate_input(images, videos)
|
||||||
image_bounds_list = []
|
image_bounds_list = []
|
||||||
valid_image_nums_ls = []
|
valid_image_nums_ls = []
|
||||||
|
|
||||||
for input_ids in batch_ids:
|
for input_ids in batch_ids:
|
||||||
input_ids_ = torch.tensor(input_ids)
|
input_ids_ = torch.tensor(input_ids)
|
||||||
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
|
||||||
@ -424,8 +420,8 @@ class LlavaPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
num_image_tokens += 1
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
@ -478,8 +474,8 @@ class LlavaNextPlugin(BasePlugin):
|
|||||||
else:
|
else:
|
||||||
image_seqlen = 1
|
image_seqlen = 1
|
||||||
|
|
||||||
num_image_tokens += 1
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
@ -529,8 +525,8 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
else:
|
else:
|
||||||
image_seqlen = 1
|
image_seqlen = 1
|
||||||
|
|
||||||
num_image_tokens += 1
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", self.image_token)
|
message["content"] = content.replace("{{image}}", self.image_token)
|
||||||
|
|
||||||
@ -586,8 +582,8 @@ class PaliGemmaPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
num_image_tokens += 1
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
message["content"] = content.replace("{{image}}", "")
|
message["content"] = content.replace("{{image}}", "")
|
||||||
|
|
||||||
@ -840,12 +836,12 @@ class VideoLlavaPlugin(BasePlugin):
|
|||||||
for message in messages:
|
for message in messages:
|
||||||
content = message["content"]
|
content = message["content"]
|
||||||
while IMAGE_PLACEHOLDER in content:
|
while IMAGE_PLACEHOLDER in content:
|
||||||
num_image_tokens += 1
|
|
||||||
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
num_video_tokens += 1
|
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
|
||||||
|
num_video_tokens += 1
|
||||||
|
|
||||||
content = content.replace("{{image}}", self.image_token)
|
content = content.replace("{{image}}", self.image_token)
|
||||||
message["content"] = content.replace("{{video}}", self.video_token)
|
message["content"] = content.replace("{{video}}", self.video_token)
|
||||||
|
@ -89,6 +89,16 @@ class Template:
|
|||||||
"""
|
"""
|
||||||
return self.format_tools.extract(content)
|
return self.format_tools.extract(content)
|
||||||
|
|
||||||
|
def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
|
||||||
|
r"""
|
||||||
|
Returns stop token ids.
|
||||||
|
"""
|
||||||
|
stop_token_ids = {tokenizer.eos_token_id}
|
||||||
|
for token in self.stop_words:
|
||||||
|
stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
|
||||||
|
|
||||||
|
return list(stop_token_ids)
|
||||||
|
|
||||||
def _encode(
|
def _encode(
|
||||||
self,
|
self,
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
@ -205,7 +215,7 @@ def _register_template(
|
|||||||
format_tools: Optional["Formatter"] = None,
|
format_tools: Optional["Formatter"] = None,
|
||||||
format_prefix: Optional["Formatter"] = None,
|
format_prefix: Optional["Formatter"] = None,
|
||||||
default_system: str = "",
|
default_system: str = "",
|
||||||
stop_words: Sequence[str] = [],
|
stop_words: Optional[Sequence[str]] = None,
|
||||||
efficient_eos: bool = False,
|
efficient_eos: bool = False,
|
||||||
replace_eos: bool = False,
|
replace_eos: bool = False,
|
||||||
replace_jinja_template: bool = False,
|
replace_jinja_template: bool = False,
|
||||||
@ -248,7 +258,7 @@ def _register_template(
|
|||||||
format_tools=format_tools or default_tool_formatter,
|
format_tools=format_tools or default_tool_formatter,
|
||||||
format_prefix=format_prefix or default_prefix_formatter,
|
format_prefix=format_prefix or default_prefix_formatter,
|
||||||
default_system=default_system,
|
default_system=default_system,
|
||||||
stop_words=stop_words,
|
stop_words=stop_words or [],
|
||||||
efficient_eos=efficient_eos,
|
efficient_eos=efficient_eos,
|
||||||
replace_eos=replace_eos,
|
replace_eos=replace_eos,
|
||||||
replace_jinja_template=replace_jinja_template,
|
replace_jinja_template=replace_jinja_template,
|
||||||
@ -566,6 +576,7 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# copied from chatml template
|
||||||
_register_template(
|
_register_template(
|
||||||
name="cpm_v",
|
name="cpm_v",
|
||||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
@ -79,6 +79,8 @@ class CustomTrainer(Trainer):
|
|||||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||||
r"""
|
r"""
|
||||||
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
|
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
|
||||||
|
|
||||||
|
It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
|
||||||
"""
|
"""
|
||||||
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
||||||
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
|
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
|
||||||
|
@ -94,6 +94,8 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||||
r"""
|
r"""
|
||||||
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
|
Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
|
||||||
|
|
||||||
|
It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
|
||||||
"""
|
"""
|
||||||
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
||||||
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
|
if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
|
||||||
|
@ -19,6 +19,7 @@ from subprocess import Popen, TimeoutExpired
|
|||||||
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
from typing import TYPE_CHECKING, Any, Dict, Generator, Optional
|
||||||
|
|
||||||
from transformers.trainer import TRAINING_ARGS_NAME
|
from transformers.trainer import TRAINING_ARGS_NAME
|
||||||
|
from transformers.utils import is_torch_npu_available
|
||||||
|
|
||||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
|
from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
|
||||||
@ -172,6 +173,7 @@ class Runner:
|
|||||||
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
if get("top.quantization_bit") in QUANTIZATION_BITS:
|
||||||
args["quantization_bit"] = int(get("top.quantization_bit"))
|
args["quantization_bit"] = int(get("top.quantization_bit"))
|
||||||
args["quantization_method"] = get("top.quantization_method")
|
args["quantization_method"] = get("top.quantization_method")
|
||||||
|
args["double_quantization"] = not is_torch_npu_available()
|
||||||
|
|
||||||
# freeze config
|
# freeze config
|
||||||
if args["finetuning_type"] == "freeze":
|
if args["finetuning_type"] == "freeze":
|
||||||
|
@ -120,6 +120,12 @@ def test_jinja_template(use_fast: bool):
|
|||||||
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
assert tokenizer.apply_chat_template(MESSAGES) == ref_tokenizer.apply_chat_template(MESSAGES)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_stop_token_ids():
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(TINY_LLAMA)
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="llama3"))
|
||||||
|
assert set(template.get_stop_token_ids(tokenizer)) == {128008, 128009}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
@pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
|
||||||
@pytest.mark.parametrize("use_fast", [True, False])
|
@pytest.mark.parametrize("use_fast", [True, False])
|
||||||
def test_gemma_template(use_fast: bool):
|
def test_gemma_template(use_fast: bool):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user