diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1adefbef..84920f4b 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -31,11 +31,20 @@ jobs:
- "ubuntu-latest"
- "windows-latest"
- "macos-13"
+ transformers:
+ - null
+ include: # test backward compatibility
+ - python: "3.9"
+ os: "ubuntu-latest"
+ transformers: "4.45.0"
+ - python: "3.9"
+ os: "ubuntu-latest"
+ transformers: "4.49.0"
runs-on: ${{ matrix.os }}
concurrency:
- group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
+ group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}
cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
env:
@@ -51,19 +60,24 @@ jobs:
with:
python-version: ${{ matrix.python }}
cache: "pip"
- cache-dependency-path: "setup.py"
+ cache-dependency-path: "**/requirements*.txt"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[torch,dev]"
+ - name: Install transformers
+ if: ${{ matrix.transformers }}
+ run: |
+ python -m pip install "transformers==${{ matrix.transformers }}"
+
- name: Cache files
id: hf-hub-cache
uses: actions/cache@v4
with:
path: ${{ runner.temp }}/huggingface
- key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
+ key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}-${{ hashFiles('tests/version.txt') }}
- name: Check quality
run: |
diff --git a/README.md b/README.md
index 699b85e8..02e99136 100644
--- a/README.md
+++ b/README.md
@@ -243,11 +243,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
-| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
+| [InternVL2.5-3](https://huggingface.co/OpenGVLab/InternVL)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
@@ -415,11 +415,11 @@ huggingface-cli login
| Mandatory | Minimum | Recommend |
| ------------ | ------- | --------- |
| python | 3.9 | 3.10 |
-| torch | 1.13.1 | 2.6.0 |
-| transformers | 4.41.2 | 4.50.0 |
+| torch | 2.0.0 | 2.6.0 |
+| transformers | 4.45.0 | 4.50.0 |
| datasets | 2.16.0 | 3.2.0 |
| accelerate | 0.34.0 | 1.2.1 |
-| peft | 0.14.0 | 0.15.0 |
+| peft | 0.14.0 | 0.15.1 |
| trl | 0.8.6 | 0.9.6 |
| Optional | Minimum | Recommend |
@@ -428,7 +428,7 @@ huggingface-cli login
| deepspeed | 0.10.0 | 0.16.4 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.8.2 |
-| flash-attn | 2.3.0 | 2.7.2 |
+| flash-attn | 2.5.6 | 2.7.2 |
### Hardware Requirement
@@ -517,6 +517,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| deepspeed | 0.13.2 | 0.13.2 |
+| vllm-ascend | - | 0.7.3 |
Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
diff --git a/README_zh.md b/README_zh.md
index 058d443a..0a473b35 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -246,11 +246,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
| [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
| [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
| [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
| [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
| [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
| [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
-| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
+| [InternVL2.5-3](https://huggingface.co/OpenGVLab/InternVL)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
| [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
@@ -418,11 +418,11 @@ huggingface-cli login
| 必需项 | 至少 | 推荐 |
| ------------ | ------- | --------- |
| python | 3.9 | 3.10 |
-| torch | 1.13.1 | 2.6.0 |
-| transformers | 4.41.2 | 4.50.0 |
+| torch | 2.0.0 | 2.6.0 |
+| transformers | 4.45.0 | 4.50.0 |
| datasets | 2.16.0 | 3.2.0 |
| accelerate | 0.34.0 | 1.2.1 |
-| peft | 0.14.0 | 0.15.0 |
+| peft | 0.14.0 | 0.15.1 |
| trl | 0.8.6 | 0.9.6 |
| 可选项 | 至少 | 推荐 |
@@ -431,7 +431,7 @@ huggingface-cli login
| deepspeed | 0.10.0 | 0.16.4 |
| bitsandbytes | 0.39.0 | 0.43.1 |
| vllm | 0.4.3 | 0.8.2 |
-| flash-attn | 2.3.0 | 2.7.2 |
+| flash-attn | 2.5.6 | 2.7.2 |
### 硬件依赖
@@ -521,6 +521,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
| torch | 2.1.0 | 2.4.0 |
| torch-npu | 2.1.0 | 2.4.0.post2 |
| deepspeed | 0.13.2 | 0.13.2 |
+| vllm-ascend | - | 0.7.3 |
请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
diff --git a/requirements.txt b/requirements.txt
index 7c26caa6..c818bb28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.5.0
accelerate>=0.34.0,<=1.6.0
peft>=0.14.0,<=0.15.1
diff --git a/scripts/api_example/test_image.py b/scripts/api_example/test_image.py
index 77d6d7c8..afd2b69c 100644
--- a/scripts/api_example/test_image.py
+++ b/scripts/api_example/test_image.py
@@ -23,8 +23,8 @@ require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def main():
client = OpenAI(
- api_key="{}".format(os.environ.get("API_KEY", "0")),
- base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+ api_key="{}".format(os.getenv("API_KEY", "0")),
+ base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
)
messages = []
messages.append(
diff --git a/scripts/api_example/test_toolcall.py b/scripts/api_example/test_toolcall.py
index 8e933914..e291ba69 100644
--- a/scripts/api_example/test_toolcall.py
+++ b/scripts/api_example/test_toolcall.py
@@ -33,8 +33,8 @@ def calculate_gpa(grades: list[str], hours: list[int]) -> float:
def main():
client = OpenAI(
- api_key="{}".format(os.environ.get("API_KEY", "0")),
- base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+ api_key="{}".format(os.getenv("API_KEY", "0")),
+ base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
)
tools = [
{
diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index 607355db..87a8959d 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -18,18 +18,11 @@ Level:
api, webui > chat, eval, train > data, model > hparams > extras
Dependency graph:
- main:
- transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
- datasets>=2.16.0,<=3.5.0
- accelerate>=0.34.0,<=1.6.0
- peft>=0.14.0,<=0.15.1
- trl>=0.8.6,<=0.9.6
- attention:
- transformers>=4.42.4 (gemma+fa2)
- longlora:
- transformers>=4.41.2,<4.48.0
- packing:
- transformers>=4.43.0
+ transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+ datasets>=2.16.0,<=3.5.0
+ accelerate>=0.34.0,<=1.6.0
+ peft>=0.14.0,<=0.15.1
+ trl>=0.8.6,<=0.9.6
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index ef39c417..20a3c190 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -25,7 +25,6 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
-from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
@@ -178,7 +177,6 @@ class HuggingfaceEngine(BaseEngine):
inputs=inputs,
attention_mask=attention_mask,
generation_config=GenerationConfig(**generating_args),
- logits_processor=get_logits_processor(),
)
mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py
index 8515e1bb..f9c32d43 100644
--- a/src/llamafactory/cli.py
+++ b/src/llamafactory/cli.py
@@ -19,7 +19,6 @@ from copy import deepcopy
from functools import partial
-
USAGE = (
"-" * 70
+ "\n"
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index cfbeefd2..8f0e7080 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -25,12 +25,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
import numpy as np
import torch
-from transformers.image_utils import (
- get_image_size,
- make_batched_videos,
- make_flat_list_of_images,
- to_numpy_array,
-)
+from transformers.image_utils import get_image_size, to_numpy_array
from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -62,6 +57,10 @@ if is_transformers_version_greater_than("4.45.0"):
)
+if is_transformers_version_greater_than("4.49.0"):
+ from transformers.image_utils import make_batched_videos, make_flat_list_of_images
+
+
if TYPE_CHECKING:
from av.stream import Stream
from numpy.typing import NDArray
@@ -487,61 +486,6 @@ class Gemma3Plugin(BasePlugin):
@dataclass
class InternVLPlugin(BasePlugin):
- @override
- def process_messages(
- self,
- messages: list[dict[str, str]],
- images: list["ImageInput"],
- videos: list["VideoInput"],
- audios: list["AudioInput"],
- processor: Optional["ProcessorMixin"],
- ) -> list[dict[str, str]]:
- self._validate_input(processor, images, videos, audios)
- num_image_tokens = 0
- num_video_tokens = 0
- image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
- messages = deepcopy(messages)
- mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-
- image_pixel_patch_list = mm_inputs.get("image_num_patches", None) # pathes of images
- video_num_patches = mm_inputs.get("video_num_patches", None) # all patches for frames of videos
- video_patch_indices = mm_inputs.get("video_patch_indices", None) # num frames of per video
-
- for message in messages:
- content = message["content"]
- while IMAGE_PLACEHOLDER in content:
- if num_image_tokens >= len(image_pixel_patch_list):
- raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
- content = content.replace(
- IMAGE_PLACEHOLDER,
- f"
{'' * image_seqlen * image_pixel_patch_list[num_image_tokens]}",
- 1,
- )
- num_image_tokens += 1
- message["content"] = content
-
- while VIDEO_PLACEHOLDER in content:
- if num_video_tokens >= len(video_patch_indices):
- raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
- current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
- end_patch_index = video_patch_indices[num_video_tokens]
- num_patches = list(video_num_patches[current_patch_index:end_patch_index])
- video_replaced_prompt = "\n".join(
- f"Frame{i + 1}:
{'' * image_seqlen * num_patches[i]}"
- for i in range(len(num_patches))
- )
- content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
- num_video_tokens += 1
- message["content"] = content
-
- if len(images) != num_image_tokens:
- raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
- if len(videos) != num_video_tokens:
- raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
- return messages
-
@override
def _get_mm_inputs(
self,
@@ -621,6 +565,63 @@ class InternVLPlugin(BasePlugin):
return mm_inputs
+ @override
+ def process_messages(
+ self,
+ messages: list[dict[str, str]],
+ images: list["ImageInput"],
+ videos: list["VideoInput"],
+ audios: list["AudioInput"],
+ processor: Optional["ProcessorMixin"],
+ ) -> list[dict[str, str]]:
+ self._validate_input(processor, images, videos, audios)
+ num_image_tokens = 0
+ num_video_tokens = 0
+ image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
+ messages = deepcopy(messages)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+ image_pixel_patch_list = mm_inputs.get("image_num_patches") # pathes of images
+ video_num_patches = mm_inputs.get("video_num_patches") # all patches for frames of videos
+ video_patch_indices = mm_inputs.get("video_patch_indices") # num frames of per video
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
+ if num_image_tokens >= len(image_pixel_patch_list):
+ raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+ content = content.replace(
+ IMAGE_PLACEHOLDER,
+ f"
{'' * image_seqlen * image_pixel_patch_list[num_image_tokens]}",
+ 1,
+ )
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ if num_video_tokens >= len(video_patch_indices):
+ raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
+ current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
+ end_patch_index = video_patch_indices[num_video_tokens]
+ num_patches = list(video_num_patches[current_patch_index:end_patch_index])
+ video_replaced_prompt = "\n".join(
+ f"Frame{i + 1}:
{'' * image_seqlen * num_patches[i]}"
+ for i in range(len(num_patches))
+ )
+ content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
+ num_video_tokens += 1
+
+ message["content"] = content
+
+ if len(images) != num_image_tokens:
+ raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+ if len(videos) != num_video_tokens:
+ raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+ return messages
+
@override
def get_mm_inputs(
self,
@@ -634,12 +635,10 @@ class InternVLPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
-
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("image_num_patches", None)
mm_inputs.pop("video_patch_indices", None)
mm_inputs.pop("video_num_patches", None)
-
return mm_inputs
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b02d6df2..1691a001 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -871,6 +871,18 @@ register_template(
)
+register_template(
+ name="granite3_vision",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+ default_system=(
+ "A chat between a curious user and an artificial intelligence assistant. "
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
+ ),
+ mm_plugin=get_mm_plugin(name="llava_next", image_token=""),
+)
+
+
register_template(
name="index",
format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index e1c3aa43..025a8bbc 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -22,7 +22,7 @@ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
-AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "