Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)
[breaking] bump transformers to 4.45.0 & improve ci (#7746)
* update ci * fix * fix * fix * fix * fix
This commit is contained in:
parent 4831552856
commit 0a0cfeb782
.github/workflows/tests.yml (vendored) — 20 lines changed
@@ -31,11 +31,20 @@ jobs:
           - "ubuntu-latest"
           - "windows-latest"
           - "macos-13"
+        transformers:
+          - null
+        include: # test backward compatibility
+          - python: "3.9"
+            os: "ubuntu-latest"
+            transformers: "4.45.0"
+          - python: "3.9"
+            os: "ubuntu-latest"
+            transformers: "4.49.0"

     runs-on: ${{ matrix.os }}

     concurrency:
-      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
+      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}
       cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

     env:
@@ -51,19 +60,24 @@ jobs:
       with:
         python-version: ${{ matrix.python }}
         cache: "pip"
-        cache-dependency-path: "setup.py"
+        cache-dependency-path: "**/requirements*.txt"

     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
         python -m pip install ".[torch,dev]"

+    - name: Install transformers
+      if: ${{ matrix.transformers }}
+      run: |
+        python -m pip install "transformers==${{ matrix.transformers }}"
+
     - name: Cache files
       id: hf-hub-cache
       uses: actions/cache@v4
       with:
         path: ${{ runner.temp }}/huggingface
-        key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
+        key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}-${{ hashFiles('tests/version.txt') }}

     - name: Check quality
       run: |
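The new `transformers` axis defaults to `null`, in which case the extra "Install transformers" step is skipped (its `if:` is falsy) and the job keeps whatever version `pip install ".[torch,dev]"` resolves; the `include` entries add two backward-compatibility jobs pinned to 4.45.0 and 4.49.0. A rough sketch of how the job list expands under standard GitHub Actions matrix semantics — the python axis values here are assumptions, since the hunk does not show them:

    from itertools import product

    # Hypothetical illustration of the expanded CI matrix after this change.
    oses = ["ubuntu-latest", "windows-latest", "macos-13"]
    pythons = ["3.9", "3.10", "3.11"]  # assumed values for the python axis

    jobs = [{"os": o, "python": p, "transformers": None} for o, p in product(oses, pythons)]
    # `include` entries that would overwrite an original matrix value become extra jobs:
    jobs += [
        {"os": "ubuntu-latest", "python": "3.9", "transformers": "4.45.0"},
        {"os": "ubuntu-latest", "python": "3.9", "transformers": "4.49.0"},
    ]

    for job in jobs:
        print(job)  # each dict corresponds to one CI job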
README.md — 13 lines changed
@@ -243,11 +243,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
 | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
-| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
+| [InternVL2.5-3](https://huggingface.co/OpenGVLab/InternVL)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
@@ -415,11 +415,11 @@ huggingface-cli login
 | Mandatory    | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | python       | 3.9     | 3.10      |
-| torch        | 1.13.1  | 2.6.0     |
-| transformers | 4.41.2  | 4.50.0    |
+| torch        | 2.0.0   | 2.6.0     |
+| transformers | 4.45.0  | 4.50.0    |
 | datasets     | 2.16.0  | 3.2.0     |
 | accelerate   | 0.34.0  | 1.2.1     |
-| peft         | 0.14.0  | 0.15.0    |
+| peft         | 0.14.0  | 0.15.1    |
 | trl          | 0.8.6   | 0.9.6     |

 | Optional  | Minimum | Recommend |
@@ -428,7 +428,7 @@ huggingface-cli login
 | deepspeed    | 0.10.0  | 0.16.4    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
 | vllm         | 0.4.3   | 0.8.2     |
-| flash-attn   | 2.3.0   | 2.7.2     |
+| flash-attn   | 2.5.6   | 2.7.2     |

 ### Hardware Requirement

@@ -517,6 +517,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | torch       | 2.1.0  | 2.4.0       |
 | torch-npu   | 2.1.0  | 2.4.0.post2 |
 | deepspeed   | 0.13.2 | 0.13.2      |
+| vllm-ascend | -      | 0.7.3       |

 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.

README_zh.md — 13 lines changed
@@ -246,11 +246,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
 | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
-| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
+| [InternVL2.5-3](https://huggingface.co/OpenGVLab/InternVL)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
@@ -418,11 +418,11 @@ huggingface-cli login
 | 必需项 | 至少 | 推荐 |
 | ------------ | ------- | --------- |
 | python | 3.9 | 3.10 |
-| torch | 1.13.1 | 2.6.0 |
-| transformers | 4.41.2 | 4.50.0 |
+| torch | 2.0.0 | 2.6.0 |
+| transformers | 4.45.0 | 4.50.0 |
 | datasets | 2.16.0 | 3.2.0 |
 | accelerate | 0.34.0 | 1.2.1 |
-| peft | 0.14.0 | 0.15.0 |
+| peft | 0.14.0 | 0.15.1 |
 | trl | 0.8.6 | 0.9.6 |

 | 可选项 | 至少 | 推荐 |
@@ -431,7 +431,7 @@ huggingface-cli login
 | deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
 | vllm | 0.4.3 | 0.8.2 |
-| flash-attn | 2.3.0 | 2.7.2 |
+| flash-attn | 2.5.6 | 2.7.2 |

 ### 硬件依赖

@@ -521,6 +521,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | torch | 2.1.0 | 2.4.0 |
 | torch-npu | 2.1.0 | 2.4.0.post2 |
 | deepspeed | 0.13.2 | 0.13.2 |
+| vllm-ascend | - | 0.7.3 |

 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。

@@ -1,4 +1,4 @@
-transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
 datasets>=2.16.0,<=3.5.0
 accelerate>=0.34.0,<=1.6.0
 peft>=0.14.0,<=0.15.1
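The bumped pin keeps the existing exclusions. As a quick sanity check (a standalone snippet using the `packaging` library, not part of the commit), the new specifier accepts 4.45.0 and 4.50.0 but still rejects the excluded 4.46–4.48 series:

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0")
    for version in ("4.44.0", "4.45.0", "4.46.1", "4.48.0", "4.50.0"):
        print(version, version in spec)
    # 4.44.0 False, 4.45.0 True, 4.46.1 False, 4.48.0 False, 4.50.0 True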
@@ -23,8 +23,8 @@ require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")

 def main():
     client = OpenAI(
-        api_key="{}".format(os.environ.get("API_KEY", "0")),
-        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+        api_key="{}".format(os.getenv("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
     )
     messages = []
     messages.append(
@@ -33,8 +33,8 @@ def calculate_gpa(grades: list[str], hours: list[int]) -> float:

 def main():
     client = OpenAI(
-        api_key="{}".format(os.environ.get("API_KEY", "0")),
-        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+        api_key="{}".format(os.getenv("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
     )
     tools = [
         {
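Much of the mechanical churn in this commit is the swap from `os.environ.get` to `os.getenv`. The two behave identically (`os.getenv` is a thin wrapper over `os.environ.get`), so these hunks are pure style changes; a minimal illustration:

    import os

    os.environ.pop("API_PORT", None)  # make sure the variable is unset
    assert os.getenv("API_PORT", 8000) == os.environ.get("API_PORT", 8000) == 8000

    os.environ["API_PORT"] = "9000"
    assert os.getenv("API_PORT") == os.environ.get("API_PORT") == "9000"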
@@ -18,18 +18,11 @@ Level:
 api, webui > chat, eval, train > data, model > hparams > extras

 Dependency graph:
-main:
-  transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+  transformers>=4.41.2,<=4.43.0,!=4.46.*,!=4.47.*,!=4.48.0
   datasets>=2.16.0,<=3.5.0
   accelerate>=0.34.0,<=1.6.0
   peft>=0.14.0,<=0.15.1
   trl>=0.8.6,<=0.9.6
-attention:
-  transformers>=4.42.4 (gemma+fa2)
-longlora:
-  transformers>=4.41.2,<4.48.0
-packing:
-  transformers>=4.43.0

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
@@ -25,7 +25,6 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
-from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response

@@ -178,7 +177,6 @@ class HuggingfaceEngine(BaseEngine):
             inputs=inputs,
             attention_mask=attention_mask,
             generation_config=GenerationConfig(**generating_args),
-            logits_processor=get_logits_processor(),
         )

         mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
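With the custom `logits_processor` gone, generation relies on the processors transformers builds internally from the `GenerationConfig`. A minimal sketch of the call shape that remains — illustrative model and values, not the engine's actual arguments:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any small causal LM works for the sketch
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    batch = tokenizer("Hello, world", return_tensors="pt")
    output_ids = model.generate(
        inputs=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        # Sampling knobs travel through GenerationConfig; no extra LogitsProcessorList is attached.
        generation_config=GenerationConfig(max_new_tokens=8, do_sample=True, temperature=0.7, top_p=0.9),
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))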
@@ -19,7 +19,6 @@ from copy import deepcopy
 from functools import partial

-

 USAGE = (
     "-" * 70
     + "\n"
@@ -25,12 +25,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union

 import numpy as np
 import torch
-from transformers.image_utils import (
-    get_image_size,
-    make_batched_videos,
-    make_flat_list_of_images,
-    to_numpy_array,
-)
+from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -62,6 +57,10 @@ if is_transformers_version_greater_than("4.45.0"):
     )


+if is_transformers_version_greater_than("4.49.0"):
+    from transformers.image_utils import make_batched_videos, make_flat_list_of_images
+
+
 if TYPE_CHECKING:
     from av.stream import Stream
     from numpy.typing import NDArray
@@ -487,61 +486,6 @@ class Gemma3Plugin(BasePlugin):

 @dataclass
 class InternVLPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: list[dict[str, str]],
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> list[dict[str, str]]:
-        self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
-        num_video_tokens = 0
-        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
-        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-
-        image_pixel_patch_list = mm_inputs.get("image_num_patches", None)  # pathes of images
-        video_num_patches = mm_inputs.get("video_num_patches", None)  # all patches for frames of videos
-        video_patch_indices = mm_inputs.get("video_patch_indices", None)  # num frames of per video
-
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_pixel_patch_list):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-                content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
-                    1,
-                )
-                num_image_tokens += 1
-            message["content"] = content
-
-            while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(video_patch_indices):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
-                end_patch_index = video_patch_indices[num_video_tokens]
-                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
-                video_replaced_prompt = "\n".join(
-                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
-                    for i in range(len(num_patches))
-                )
-                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
-                num_video_tokens += 1
-            message["content"] = content
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        return messages
-
     @override
     def _get_mm_inputs(
         self,
@@ -621,6 +565,63 @@ class InternVLPlugin(BasePlugin):

         return mm_inputs

+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+        image_pixel_patch_list = mm_inputs.get("image_num_patches")  # pathes of images
+        video_num_patches = mm_inputs.get("video_num_patches")  # all patches for frames of videos
+        video_patch_indices = mm_inputs.get("video_patch_indices")  # num frames of per video
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_pixel_patch_list):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
+                    1,
+                )
+                num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(video_patch_indices):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
+                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
+                end_patch_index = video_patch_indices[num_video_tokens]
+                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
+                video_replaced_prompt = "\n".join(
+                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
+                    for i in range(len(num_patches))
+                )
+                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
+                num_video_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
     @override
     def get_mm_inputs(
         self,
@@ -634,12 +635,10 @@ class InternVLPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
-
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs.pop("image_num_patches", None)
         mm_inputs.pop("video_patch_indices", None)
         mm_inputs.pop("video_num_patches", None)
-
         return mm_inputs


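The relocated `process_messages` only swaps `mm_inputs.get(key, None)` for `mm_inputs.get(key)` and moves the method below `_get_mm_inputs`; the placeholder-expansion logic is unchanged. A self-contained sketch of what that expansion does for images, with made-up patch counts and `image_seqlen` purely for illustration:

    IMAGE_PLACEHOLDER = "<image>"

    def expand_image_placeholders(content: str, image_num_patches: list[int], image_seqlen: int = 4) -> str:
        # Mimics the image branch of InternVLPlugin.process_messages with toy values.
        num_image_tokens = 0
        while IMAGE_PLACEHOLDER in content:
            if num_image_tokens >= len(image_num_patches):
                raise ValueError("`len(images)` is less than the number of <image> tokens.")
            expanded = "<IMG_CONTEXT>" * (image_seqlen * image_num_patches[num_image_tokens])
            content = content.replace(IMAGE_PLACEHOLDER, f"<img>{expanded}</img>", 1)
            num_image_tokens += 1
        return content

    # Two images with 1 and 2 patches -> 4 and 8 <IMG_CONTEXT> tokens respectively.
    print(expand_image_placeholders("<image> and <image> compared", image_num_patches=[1, 2]))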
@@ -871,6 +871,18 @@ register_template(
 )


+register_template(
+    name="granite3_vision",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
 register_template(
     name="index",
     format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
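The slots above are plain format strings with a `{{content}}` placeholder, so a single user turn under the new granite3_vision template renders roughly as follows — hand-expanded with Python's str.format for illustration, not a call into LLaMA-Factory's own StringFormatter:

    default_system = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    system_slot = "<|system|>\n{content}\n"
    user_slot = "<|user|>\n{content}\n<|assistant|>\n"

    prompt = system_slot.format(content=default_system) + user_slot.format(content="<image>\nWhat is in this picture?")
    print(prompt)  # the model's reply is generated after the trailing <|assistant|> tag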
@@ -22,7 +22,7 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME


-AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")
+AUDIO_PLACEHOLDER = os.getenv("AUDIO_PLACEHOLDER", "<audio>")

 CHECKPOINT_NAMES = {
     SAFE_ADAPTER_WEIGHTS_NAME,
@@ -50,7 +50,7 @@ FILEEXT2TYPE = {

 IGNORE_INDEX = -100

-IMAGE_PLACEHOLDER = os.environ.get("IMAGE_PLACEHOLDER", "<image>")
+IMAGE_PLACEHOLDER = os.getenv("IMAGE_PLACEHOLDER", "<image>")

 LAYERNORM_NAMES = {"norm", "ln"}

@@ -89,7 +89,7 @@ SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

 SWANLAB_CONFIG = "swanlab_public_config.json"

-VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
+VIDEO_PLACEHOLDER = os.getenv("VIDEO_PLACEHOLDER", "<video>")

 V_HEAD_WEIGHTS_NAME = "value_head.bin"

|
|||||||
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
|
DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
|
||||||
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
|
||||||
},
|
},
|
||||||
|
"Granite-3.2-2B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "ibm-granite/granite-3.2-2b-instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-2b-instruct",
|
||||||
|
},
|
||||||
|
"Granite-3.2-8B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "ibm-granite/granite-3.2-8b-instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.2-8b-instruct",
|
||||||
|
},
|
||||||
|
"Granite-3.3-2B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-base",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-base",
|
||||||
|
},
|
||||||
|
"Granite-3.3-8B-Base": {
|
||||||
|
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-base",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-base",
|
||||||
|
},
|
||||||
|
"Granite-3.3-2B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-2b-instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-2b-instruct",
|
||||||
|
},
|
||||||
|
"Granite-3.3-8B-Instruct": {
|
||||||
|
DownloadSource.DEFAULT: "ibm-granite/granite-3.3-8b-instruct",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.3-8b-instruct",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
template="granite3",
|
template="granite3",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Granite-3.2-1B-A400M-Base": {
|
||||||
|
DownloadSource.DEFAULT: "ibm-granite/granite-vision-3.2-2b",
|
||||||
|
DownloadSource.MODELSCOPE: "AI-ModelScope/granite-vision-3.2-2b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="granite3_vision",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Hunyuan-7B-Instruct": {
|
"Hunyuan-7B-Instruct": {
|
||||||
@ -967,26 +1002,33 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"InternVL2_5-1B-MPO": {
|
"InternVL2.5-1B-MPO": {
|
||||||
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-1B-MPO-hf",
|
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-1B-MPO-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "llamafactory/InternVL2_5-1B-MPO-hf",
|
||||||
},
|
},
|
||||||
"InternVL2_5-2B-MPO": {
|
"InternVL2.5-2B-MPO": {
|
||||||
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-2B-MPO-hf",
|
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-2B-MPO-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "llamafactory/InternVL2_5-2B-MPO-hf",
|
||||||
},
|
},
|
||||||
"InternVL2_5-4B-MPO": {
|
"InternVL2.5-4B-MPO": {
|
||||||
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-4B-MPO-hf",
|
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-4B-MPO-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "llamafactory/InternVL2_5-4B-MPO-hf",
|
||||||
},
|
},
|
||||||
"InternVL2_5-8B-MPO": {
|
"InternVL2.5-8B-MPO": {
|
||||||
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-8B-MPO-hf",
|
DownloadSource.DEFAULT: "kingsley01/InternVL2_5-8B-MPO-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "llamafactory/InternVL2_5-8B-MPO-hf",
|
||||||
},
|
},
|
||||||
"InternVL3-1B-hf": {
|
"InternVL3-1B-hf": {
|
||||||
DownloadSource.DEFAULT: "kingsley01/InternVL3-1B-hf",
|
DownloadSource.DEFAULT: "kingsley01/InternVL3-1B-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "llamafactory/InternVL3-1B-hf",
|
||||||
},
|
},
|
||||||
"InternVL3-2B-hf": {
|
"InternVL3-2B-hf": {
|
||||||
DownloadSource.DEFAULT: "kingsley01/InternVL3-2B-hf",
|
DownloadSource.DEFAULT: "kingsley01/InternVL3-2B-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "llamafactory/InternVL3-2B-hf",
|
||||||
},
|
},
|
||||||
"InternVL3-8B-hf": {
|
"InternVL3-8B-hf": {
|
||||||
DownloadSource.DEFAULT: "kingsley01/InternVL3-8B-hf",
|
DownloadSource.DEFAULT: "kingsley01/InternVL3-8B-hf",
|
||||||
|
DownloadSource.MODELSCOPE: "llamafactory/InternVL3-8B-hf",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
template="intern_vl",
|
template="intern_vl",
|
||||||
|
@@ -79,7 +79,7 @@ class _Logger(logging.Logger):

 def _get_default_logging_level() -> "logging._Level":
     r"""Return the default logging level."""
-    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    env_level_str = os.getenv("LLAMAFACTORY_VERBOSITY", None)
     if env_level_str:
         if env_level_str.upper() in logging._nameToLevel:
             return logging._nameToLevel[env_level_str.upper()]
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.43.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.5.0")
     check_version("accelerate>=0.34.0,<=1.6.0")
     check_version("peft>=0.14.0,<=0.15.1")
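`check_version` (signature shown in the hunk header) gates startup on these ranges unless `DISABLE_VERSION_CHECK=1` is set. A rough stand-alone equivalent built on `packaging` and `importlib.metadata` — an assumption about its behavior, not the repository's actual implementation:

    import os
    from importlib.metadata import version
    from packaging.requirements import Requirement

    def check_version_sketch(requirement: str, mandatory: bool = False) -> None:
        # Illustrative only: honor the documented escape hatch for soft checks.
        if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ("1", "true") and not mandatory:
            return
        req = Requirement(requirement)
        installed = version(req.name)
        if not req.specifier.contains(installed, prereleases=True):
            raise RuntimeError(f"{req.name}=={installed} does not satisfy '{requirement}'.")

    check_version_sketch("datasets>=2.16.0,<=3.5.0")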
@@ -141,13 +141,13 @@ def count_parameters(model: "torch.nn.Module") -> tuple[int, int]:
 def get_current_device() -> "torch.device":
     r"""Get the current available device."""
     if is_torch_xpu_available():
-        device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_npu_available():
-        device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_mps_available():
-        device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
     elif is_torch_cuda_available():
-        device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
+        device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
     else:
         device = "cpu"

@@ -155,11 +155,13 @@ def get_current_device() -> "torch.device":


 def get_device_count() -> int:
-    r"""Get the number of available GPU or NPU devices."""
+    r"""Get the number of available devices."""
     if is_torch_xpu_available():
         return torch.xpu.device_count()
     elif is_torch_npu_available():
         return torch.npu.device_count()
+    elif is_torch_mps_available():
+        return torch.mps.device_count()
     elif is_torch_cuda_available():
         return torch.cuda.device_count()
     else:
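With the new `mps` branch, Apple-silicon machines report their accelerator here instead of falling through to the default; Apple exposes at most one MPS device, so multi-process launch decisions still collapse to a single process there. A tiny usage sketch with a hypothetical caller, not code from this commit:

    import torch

    def pick_launch_mode(device_count: int) -> str:
        # Hypothetical helper: decide whether a multi-process launcher is worthwhile.
        return "distributed" if device_count > 1 else "single-process"

    print(pick_launch_mode(torch.cuda.device_count()))  # "single-process" on a laptop, "distributed" on a multi-GPU box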
@@ -175,10 +177,12 @@ def get_logits_processor() -> "LogitsProcessorList":

 def get_peak_memory() -> tuple[int, int]:
     r"""Get the peak memory usage for the current device (in Bytes)."""
-    if is_torch_npu_available():
-        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
-    elif is_torch_xpu_available():
+    if is_torch_xpu_available():
         return torch.xpu.max_memory_allocated(), torch.xpu.max_memory_reserved()
+    elif is_torch_npu_available():
+        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_mps_available():
+        return torch.mps.current_allocated_memory(), -1
     elif is_torch_cuda_available():
         return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
     else:
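The MPS backend does not expose a reserved-memory peak in the same way, hence the `-1` sentinel in the second slot; callers that log peak VRAM (for example under `RECORD_VRAM=1`) need to treat it as "unavailable". A tiny illustration of how a caller might format the pair — a hypothetical helper, not from the repository:

    def format_peak_memory(allocated: int, reserved: int) -> str:
        # Hypothetical pretty-printer for the (allocated, reserved) tuple from get_peak_memory().
        gib = 1024 ** 3
        reserved_str = "n/a" if reserved < 0 else f"{reserved / gib:.2f} GiB"
        return f"allocated={allocated / gib:.2f} GiB, reserved={reserved_str}"

    print(format_peak_memory(3 * 1024 ** 3, -1))  # allocated=3.00 GiB, reserved=n/a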
@@ -200,9 +204,11 @@ def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
     return torch.float32


-def is_gpu_or_npu_available() -> bool:
-    r"""Check if the GPU or NPU is available."""
-    return is_torch_npu_available() or is_torch_cuda_available() or is_torch_xpu_available()
+def is_accelerator_available() -> bool:
+    r"""Check if the accelerator is available."""
+    return (
+        is_torch_xpu_available() or is_torch_npu_available() or is_torch_mps_available() or is_torch_cuda_available()
+    )


 def is_env_enabled(env_var: str, default: str = "0") -> bool:
@@ -229,7 +235,7 @@ def skip_check_imports() -> None:


 def torch_gc() -> None:
-    r"""Collect GPU or NPU memory."""
+    r"""Collect the device memory."""
     gc.collect()
     if is_torch_xpu_available():
         torch.xpu.empty_cache()
@@ -280,7 +286,7 @@ def use_ray() -> bool:


 def find_available_port() -> int:
-    """Find an available port on the local machine."""
+    r"""Find an available port on the local machine."""
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.bind(("", 0))
     port = sock.getsockname()[1]
@@ -288,8 +294,8 @@ def find_available_port() -> int:
     return port


-def fix_proxy(ipv6_enabled: bool) -> None:
-    """Fix proxy settings for gradio ui."""
+def fix_proxy(ipv6_enabled: bool = False) -> None:
+    r"""Fix proxy settings for gradio ui."""
     os.environ["no_proxy"] = "localhost,127.0.0.1,0.0.0.0"
     if ipv6_enabled:
         for name in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
@@ -19,7 +19,6 @@ import torch
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
-    AutoModelForImageTextToText,
     AutoModelForSeq2SeqLM,
     AutoModelForTextToWaveform,
     AutoModelForVision2Seq,
@@ -30,6 +29,7 @@ from trl import AutoModelForCausalLMWithValueHead

 from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
+from ..extras.packages import is_transformers_version_greater_than
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
 from .model_utils.misc import register_autoclass
@@ -39,6 +39,10 @@ from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_processor, patch_tokenizer, patch_valuehead_model


+if is_transformers_version_greater_than("4.46.0"):
+    from transformers import AutoModelForImageTextToText
+
+
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer, ProcessorMixin

@@ -145,7 +149,10 @@ def load_model(
     else:
         if type(config) in AutoModelForVision2Seq._model_mapping.keys():  # image-text
             load_class = AutoModelForVision2Seq
-        elif type(config) in AutoModelForImageTextToText._model_mapping.keys():  # image-text
+        elif (
+            is_transformers_version_greater_than("4.46.0")
+            and type(config) in AutoModelForImageTextToText._model_mapping.keys()
+        ):  # image-text
             load_class = AutoModelForImageTextToText
         elif type(config) in AutoModelForSeq2SeqLM._model_mapping.keys():  # audio-text
             load_class = AutoModelForSeq2SeqLM
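`AutoModelForImageTextToText` only exists in newer transformers releases (the guard here uses 4.46.0), so both the import and the branch are now version-gated; with the minimum pin at 4.45.0, older environments simply fall back to `AutoModelForVision2Seq`. The same guard pattern in isolation, as a generic sketch that uses `packaging` rather than the project's own helper:

    from importlib.metadata import version
    from packaging.version import Version

    def transformers_at_least(minimum: str) -> bool:
        # Generic stand-in for a "version greater than" helper.
        return Version(version("transformers")) >= Version(minimum)

    if transformers_at_least("4.46.0"):
        from transformers import AutoModelForImageTextToText  # noqa: F401  (only resolvable on newer releases)
    else:
        AutoModelForImageTextToText = None  # callers must check the guard before using it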
@@ -18,7 +18,6 @@ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_availabl

 from ...extras import logging
 from ...extras.constants import AttentionFunction
-from ...extras.misc import check_version


 if TYPE_CHECKING:
@@ -36,8 +35,6 @@ def configure_attn_implementation(
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
-                check_version("transformers>=4.42.4")
-                check_version("flash_attn>=2.6.3")
                 if model_args.flash_attn != AttentionFunction.FA2:
                     logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = AttentionFunction.FA2
|
@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(
|
|||||||
|
|
||||||
|
|
||||||
def _apply_llama_patch() -> None:
|
def _apply_llama_patch() -> None:
|
||||||
check_version("transformers>=4.41.2,<4.48.0")
|
check_version("transformers>=4.43.0,<4.48.0", mandatory=True)
|
||||||
LlamaAttention.forward = llama_attention_forward
|
LlamaAttention.forward = llama_attention_forward
|
||||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
||||||
|
@@ -43,7 +43,6 @@ import torch
 import torch.nn.functional as F

 from ...extras import logging
-from ...extras.misc import check_version
 from ...extras.packages import is_transformers_version_greater_than


@@ -117,6 +116,5 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    check_version("transformers>=4.43.0")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
@@ -188,7 +188,7 @@ class LogCallback(TrainerCallback):
         self.webui_mode = is_env_enabled("LLAMABOARD_ENABLED")
         if self.webui_mode and not use_ray():
             signal.signal(signal.SIGABRT, self._set_abort)
-            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
+            self.logger_handler = logging.LoggerHandler(os.getenv("LLAMABOARD_WORKDIR"))
             logging.add_handler(self.logger_handler)
             transformers.logging.add_handler(self.logger_handler)

@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Optional
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from ...extras.misc import calculate_tps, get_logits_processor
+from ...extras.misc import calculate_tps
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -82,7 +82,6 @@ def run_sft(
     gen_kwargs = generating_args.to_dict(obey_generation_config=True)
     gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
-    gen_kwargs["logits_processor"] = get_logits_processor()

     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
@@ -77,10 +77,10 @@ class WebChatModel(ChatModel):
         if not lazy_init:  # read arguments from command line
             super().__init__()

-        if demo_mode and os.environ.get("DEMO_MODEL") and os.environ.get("DEMO_TEMPLATE"):  # load demo model
-            model_name_or_path = os.environ.get("DEMO_MODEL")
-            template = os.environ.get("DEMO_TEMPLATE")
-            infer_backend = os.environ.get("DEMO_BACKEND", "huggingface")
+        if demo_mode and os.getenv("DEMO_MODEL") and os.getenv("DEMO_TEMPLATE"):  # load demo model
+            model_name_or_path = os.getenv("DEMO_MODEL")
+            template = os.getenv("DEMO_TEMPLATE")
+            infer_backend = os.getenv("DEMO_BACKEND", "huggingface")
             super().__init__(
                 dict(model_name_or_path=model_name_or_path, template=template, infer_backend=infer_backend)
             )
@@ -23,7 +23,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
-from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
+from ..extras.misc import is_accelerator_available, torch_gc, use_ray
 from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
@@ -108,7 +108,7 @@ class Runner:
         if not get("eval.output_dir"):
             return ALERTS["err_no_output_dir"][lang]

-        if not from_preview and not is_gpu_or_npu_available():
+        if not from_preview and not is_accelerator_available():
             gr.Warning(ALERTS["warn_no_cuda"][lang])

         return ""
@@ -20,6 +20,7 @@ import torch
 from PIL import Image

 from llamafactory.data.mm_plugin import get_mm_plugin
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer

@@ -137,6 +138,7 @@ def test_base_plugin():


 @pytest.mark.skipif(not HF_TOKEN, reason="Gated model.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
 def test_gemma3_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-3-4b-it")
@@ -157,7 +159,7 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="cache failure.")
+@pytest.mark.xfail(reason="Unknown error.")
 def test_internvl_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="kingsley01/InternVL2_5-1B-MPO-hf")
@@ -196,6 +198,7 @@ def test_llama4_plugin():
     _check_plugin(**check_inputs)


+@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 def test_llava_plugin():
     image_seqlen = 576
     tokenizer_module = _load_tokenizer_module(model_name_or_path="llava-hf/llava-1.5-7b-hf")
@@ -254,6 +257,7 @@ def test_paligemma_plugin():
     _check_plugin(**check_inputs)


+@pytest.mark.skipif(not is_transformers_version_greater_than("4.50.0"), reason="Requires transformers>=4.50.0")
 def test_pixtral_plugin():
     image_slice_height, image_slice_width = 2, 2
     tokenizer_module = _load_tokenizer_module(model_name_or_path="mistral-community/pixtral-12b")
@@ -291,6 +295,7 @@ def test_qwen2_vl_plugin():
     _check_plugin(**check_inputs)


+@pytest.mark.skipif(not is_transformers_version_greater_than("4.47.0"), reason="Requires transformers>=4.47.0")
 def test_video_llava_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="LanguageBind/Video-LLaVA-7B-hf")
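These markers mirror the new CI matrix: on the backward-compatibility jobs pinned to transformers 4.45.0 or 4.49.0, plugins that need a newer release are skipped instead of failing. The guard itself is just a version comparison; a minimal pytest-style reproduction of the pattern, as a standalone sketch using `packaging` rather than the project's helper:

    import pytest
    from importlib.metadata import version
    from packaging.version import Version

    def transformers_version_at_least(minimum: str) -> bool:
        return Version(version("transformers")) >= Version(minimum)

    @pytest.mark.skipif(not transformers_version_at_least("4.50.0"), reason="Requires transformers>=4.50.0")
    def test_needs_new_processor_api():
        # Body only runs on the default (latest) CI leg, not on the 4.45.0 / 4.49.0 back-compat legs.
        assert transformers_version_at_least("4.50.0")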