diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 1adefbef..84920f4b 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -31,11 +31,20 @@ jobs:
           - "ubuntu-latest"
           - "windows-latest"
           - "macos-13"
+        transformers:
+          - null
+        include: # test backward compatibility
+          - python: "3.9"
+            os: "ubuntu-latest"
+            transformers: "4.45.0"
+          - python: "3.9"
+            os: "ubuntu-latest"
+            transformers: "4.49.0"
 
     runs-on: ${{ matrix.os }}
 
     concurrency:
-      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}
+      group: ${{ github.workflow }}-${{ github.ref }}-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}
       cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
 
     env:
@@ -51,19 +60,24 @@ jobs:
         with:
           python-version: ${{ matrix.python }}
           cache: "pip"
-          cache-dependency-path: "setup.py"
+          cache-dependency-path: "**/requirements*.txt"
 
       - name: Install dependencies
         run: |
          python -m pip install --upgrade pip
          python -m pip install ".[torch,dev]"
 
+      - name: Install transformers
+        if: ${{ matrix.transformers }}
+        run: |
+          python -m pip install "transformers==${{ matrix.transformers }}"
+
      - name: Cache files
        id: hf-hub-cache
        uses: actions/cache@v4
        with:
          path: ${{ runner.temp }}/huggingface
-         key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ hashFiles('tests/version.txt') }}
+         key: huggingface-${{ matrix.os }}-${{ matrix.python }}-${{ matrix.transformers }}-${{ hashFiles('tests/version.txt') }}
 
      - name: Check quality
        run: |
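The new `include` entries pin the two oldest supported `transformers` releases (4.45.0 and 4.49.0), so regressions against the published minimums surface in CI instead of in user installs. A minimal sketch of how an individual test could gate on those pins (illustrative only; `pytest` and `packaging` are assumed, and the marker and test names are hypothetical):

```python
# Hypothetical version gate for tests that exercise newer transformers APIs.
import pytest
import transformers
from packaging.version import Version

requires_transformers_4_49 = pytest.mark.skipif(
    Version(transformers.__version__) < Version("4.49.0"),
    reason="make_batched_videos requires transformers>=4.49.0",
)


@requires_transformers_4_49
def test_make_batched_videos_importable():
    from transformers.image_utils import make_batched_videos  # noqa: F401
```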
diff --git a/README.md b/README.md
index 699b85e8..02e99136 100644
--- a/README.md
+++ b/README.md
@@ -243,11 +243,11 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
 | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
-| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
+| [InternVL2.5-3](https://huggingface.co/OpenGVLab/InternVL)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
@@ -415,11 +415,11 @@ huggingface-cli login
 | Mandatory    | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | python       | 3.9     | 3.10      |
-| torch        | 1.13.1  | 2.6.0     |
-| transformers | 4.41.2  | 4.50.0    |
+| torch        | 2.0.0   | 2.6.0     |
+| transformers | 4.45.0  | 4.50.0    |
 | datasets     | 2.16.0  | 3.2.0     |
 | accelerate   | 0.34.0  | 1.2.1     |
-| peft         | 0.14.0  | 0.15.0    |
+| peft         | 0.14.0  | 0.15.1    |
 | trl          | 0.8.6   | 0.9.6     |
 
 | Optional     | Minimum | Recommend |
@@ -428,7 +428,7 @@ huggingface-cli login
 | deepspeed    | 0.10.0  | 0.16.4    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
 | vllm         | 0.4.3   | 0.8.2     |
-| flash-attn   | 2.3.0   | 2.7.2     |
+| flash-attn   | 2.5.6   | 2.7.2     |
 
 ### Hardware Requirement
 
@@ -517,6 +517,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | torch       | 2.1.0  | 2.4.0       |
 | torch-npu   | 2.1.0  | 2.4.0.post2 |
 | deepspeed   | 0.13.2 | 0.13.2      |
+| vllm-ascend | -      | 0.7.3       |
 
 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
diff --git a/README_zh.md b/README_zh.md
index 058d443a..0a473b35 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -246,11 +246,11 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Gemma 3](https://huggingface.co/google) | 1B/4B/12B/27B | gemma3/gemma (1B) |
 | [GLM-4/GLM-4-0414/GLM-Z1](https://huggingface.co/THUDM) | 9B/32B | glm4 |
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
-| [Granite 3.0-3.1](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
+| [Granite 3.0-3.3](https://huggingface.co/ibm-granite) | 1B/2B/3B/8B | granite3 |
 | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan |
 | [Index](https://huggingface.co/IndexTeam) | 1.9B | index |
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
-| [InternVL2_5-3](https://huggingface.co/OpenGVLab/InternVL) | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
+| [InternVL2.5-3](https://huggingface.co/OpenGVLab/InternVL)\*\* | 1B/2B/4B/8B/9B/14B/26B/38B/78B | intern_vl |
 | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl |
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
@@ -418,11 +418,11 @@ huggingface-cli login
 | 必需项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
 | python       | 3.9     | 3.10      |
-| torch        | 1.13.1  | 2.6.0     |
-| transformers | 4.41.2  | 4.50.0    |
+| torch        | 2.0.0   | 2.6.0     |
+| transformers | 4.45.0  | 4.50.0    |
 | datasets     | 2.16.0  | 3.2.0     |
 | accelerate   | 0.34.0  | 1.2.1     |
-| peft         | 0.14.0  | 0.15.0    |
+| peft         | 0.14.0  | 0.15.1    |
 | trl          | 0.8.6   | 0.9.6     |
 
 | 可选项       | 至少    | 推荐      |
@@ -431,7 +431,7 @@ huggingface-cli login
 | deepspeed    | 0.10.0  | 0.16.4    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
 | vllm         | 0.4.3   | 0.8.2     |
-| flash-attn   | 2.3.0   | 2.7.2     |
+| flash-attn   | 2.5.6   | 2.7.2     |
 
 ### 硬件依赖
 
@@ -521,6 +521,7 @@ source /usr/local/Ascend/ascend-toolkit/set_env.sh
 | torch       | 2.1.0  | 2.4.0       |
 | torch-npu   | 2.1.0  | 2.4.0.post2 |
 | deepspeed   | 0.13.2 | 0.13.2      |
+| vllm-ascend | -      | 0.7.3       |
 
 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
diff --git a/requirements.txt b/requirements.txt
index 7c26caa6..c818bb28 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
 datasets>=2.16.0,<=3.5.0
 accelerate>=0.34.0,<=1.6.0
 peft>=0.14.0,<=0.15.1
diff --git a/scripts/api_example/test_image.py b/scripts/api_example/test_image.py
index 77d6d7c8..afd2b69c 100644
--- a/scripts/api_example/test_image.py
+++ b/scripts/api_example/test_image.py
@@ -23,8 +23,8 @@ require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
 
 def main():
     client = OpenAI(
-        api_key="{}".format(os.environ.get("API_KEY", "0")),
-        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+        api_key="{}".format(os.getenv("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
     )
     messages = []
     messages.append(
diff --git a/scripts/api_example/test_toolcall.py b/scripts/api_example/test_toolcall.py
index 8e933914..e291ba69 100644
--- a/scripts/api_example/test_toolcall.py
+++ b/scripts/api_example/test_toolcall.py
@@ -33,8 +33,8 @@ def calculate_gpa(grades: list[str], hours: list[int]) -> float:
 
 def main():
     client = OpenAI(
-        api_key="{}".format(os.environ.get("API_KEY", "0")),
-        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
+        api_key="{}".format(os.getenv("API_KEY", "0")),
+        base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)),
     )
     tools = [
         {
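The only change to the two API example scripts is swapping `os.environ.get` for `os.getenv`. The two are interchangeable (`os.getenv` is documented as equivalent to `os.environ.get`), so behavior is unchanged; a quick check:

```python
# os.getenv(key, default) and os.environ.get(key, default) return the same
# value whether or not the variable is set, so the swap is purely stylistic.
import os

assert os.getenv("API_PORT", 8000) == os.environ.get("API_PORT", 8000)
```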
api_key="{}".format(os.environ.get("API_KEY", "0")), - base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)), + api_key="{}".format(os.getenv("API_KEY", "0")), + base_url="http://localhost:{}/v1".format(os.getenv("API_PORT", 8000)), ) tools = [ { diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py index 607355db..87a8959d 100644 --- a/src/llamafactory/__init__.py +++ b/src/llamafactory/__init__.py @@ -18,18 +18,11 @@ Level: api, webui > chat, eval, train > data, model > hparams > extras Dependency graph: - main: - transformers>=4.41.2,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0 - datasets>=2.16.0,<=3.5.0 - accelerate>=0.34.0,<=1.6.0 - peft>=0.14.0,<=0.15.1 - trl>=0.8.6,<=0.9.6 - attention: - transformers>=4.42.4 (gemma+fa2) - longlora: - transformers>=4.41.2,<4.48.0 - packing: - transformers>=4.43.0 + transformers>=4.41.2,<=4.43.0,!=4.46.*,!=4.47.*,!=4.48.0 + datasets>=2.16.0,<=3.5.0 + accelerate>=0.34.0,<=1.6.0 + peft>=0.14.0,<=0.15.1 + trl>=0.8.6,<=0.9.6 Disable version checking: DISABLE_VERSION_CHECK=1 Enable VRAM recording: RECORD_VRAM=1 diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index ef39c417..20a3c190 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -25,7 +25,6 @@ from typing_extensions import override from ..data import get_template_and_fix_tokenizer from ..extras import logging from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName -from ..extras.misc import get_logits_processor from ..model import load_model, load_tokenizer from .base_engine import BaseEngine, Response @@ -178,7 +177,6 @@ class HuggingfaceEngine(BaseEngine): inputs=inputs, attention_mask=attention_mask, generation_config=GenerationConfig(**generating_args), - logits_processor=get_logits_processor(), ) mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor) diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 8515e1bb..f9c32d43 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -19,7 +19,6 @@ from copy import deepcopy from functools import partial - USAGE = ( "-" * 70 + "\n" diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index cfbeefd2..8f0e7080 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -25,12 +25,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union import numpy as np import torch -from transformers.image_utils import ( - get_image_size, - make_batched_videos, - make_flat_list_of_images, - to_numpy_array, -) +from transformers.image_utils import get_image_size, to_numpy_array from typing_extensions import override from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER @@ -62,6 +57,10 @@ if is_transformers_version_greater_than("4.45.0"): ) +if is_transformers_version_greater_than("4.49.0"): + from transformers.image_utils import make_batched_videos, make_flat_list_of_images + + if TYPE_CHECKING: from av.stream import Stream from numpy.typing import NDArray @@ -487,61 +486,6 @@ class Gemma3Plugin(BasePlugin): @dataclass class InternVLPlugin(BasePlugin): - @override - def process_messages( - self, - messages: list[dict[str, str]], - images: list["ImageInput"], - videos: list["VideoInput"], - audios: list["AudioInput"], - processor: Optional["ProcessorMixin"], - ) -> list[dict[str, str]]: - 
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index cfbeefd2..8f0e7080 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -25,12 +25,7 @@ from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
 import numpy as np
 import torch
-from transformers.image_utils import (
-    get_image_size,
-    make_batched_videos,
-    make_flat_list_of_images,
-    to_numpy_array,
-)
+from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -62,6 +57,10 @@ if is_transformers_version_greater_than("4.45.0"):
     )
 
 
+if is_transformers_version_greater_than("4.49.0"):
+    from transformers.image_utils import make_batched_videos, make_flat_list_of_images
+
+
 if TYPE_CHECKING:
     from av.stream import Stream
     from numpy.typing import NDArray
@@ -487,61 +486,6 @@ class Gemma3Plugin(BasePlugin):
 
 @dataclass
 class InternVLPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: list[dict[str, str]],
-        images: list["ImageInput"],
-        videos: list["VideoInput"],
-        audios: list["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> list[dict[str, str]]:
-        self._validate_input(processor, images, videos, audios)
-        num_image_tokens = 0
-        num_video_tokens = 0
-        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
-        messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-
-        image_pixel_patch_list = mm_inputs.get("image_num_patches", None)  # patches per image
-        video_num_patches = mm_inputs.get("video_num_patches", None)  # patches for each video frame
-        video_patch_indices = mm_inputs.get("video_patch_indices", None)  # number of frames per video
-
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                if num_image_tokens >= len(image_pixel_patch_list):
-                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
-                content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
-                    1,
-                )
-                num_image_tokens += 1
-            message["content"] = content
-
-            while VIDEO_PLACEHOLDER in content:
-                if num_video_tokens >= len(video_patch_indices):
-                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
-                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
-                end_patch_index = video_patch_indices[num_video_tokens]
-                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
-                video_replaced_prompt = "\n".join(
-                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
-                    for i in range(len(num_patches))
-                )
-                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
-                num_video_tokens += 1
-            message["content"] = content
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
-
-        return messages
-
     @override
     def _get_mm_inputs(
         self,
@@ -621,6 +565,63 @@ class InternVLPlugin(BasePlugin):
         return mm_inputs
 
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        image_seqlen = getattr(processor, "image_seq_length") if self.expand_mm_tokens else 1
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
+        image_pixel_patch_list = mm_inputs.get("image_num_patches")  # patches per image
+        video_num_patches = mm_inputs.get("video_num_patches")  # patches for each video frame
+        video_patch_indices = mm_inputs.get("video_patch_indices")  # number of frames per video
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if num_image_tokens >= len(image_pixel_patch_list):
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
+
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_pixel_patch_list[num_image_tokens]}</img>",
+                    1,
+                )
+                num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                if num_video_tokens >= len(video_patch_indices):
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
+
+                current_patch_index = video_patch_indices[num_video_tokens - 1] if num_video_tokens > 0 else 0
+                end_patch_index = video_patch_indices[num_video_tokens]
+                num_patches = list(video_num_patches[current_patch_index:end_patch_index])
+                video_replaced_prompt = "\n".join(
+                    f"Frame{i + 1}: <img>{'<IMG_CONTEXT>' * image_seqlen * num_patches[i]}</img>"
+                    for i in range(len(num_patches))
+                )
+                content = content.replace(VIDEO_PLACEHOLDER, video_replaced_prompt, 1)
+                num_video_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
     @override
     def get_mm_inputs(
         self,
@@ -634,12 +635,10 @@ class InternVLPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
-
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         mm_inputs.pop("image_num_patches", None)
         mm_inputs.pop("video_patch_indices", None)
         mm_inputs.pop("video_num_patches", None)
-
         return mm_inputs
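To make the placeholder arithmetic concrete: each `<image>` placeholder is rewritten into `image_seq_length` context tokens per image patch, wrapped in `<img>...</img>`. A standalone toy run of the image branch (the numbers are invented; real InternVL processors use a much larger `image_seq_length`, typically 256):

```python
IMAGE_PLACEHOLDER = "<image>"
image_seqlen = 4  # stands in for processor.image_seq_length
image_num_patches = [2]  # the first image was tiled into two patches

content = f"{IMAGE_PLACEHOLDER} Describe this picture."
expanded = content.replace(
    IMAGE_PLACEHOLDER,
    f"<img>{'<IMG_CONTEXT>' * image_seqlen * image_num_patches[0]}</img>",
    1,
)
# 4 tokens per patch * 2 patches = 8 <IMG_CONTEXT> tokens
assert expanded.count("<IMG_CONTEXT>") == 8
```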
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index b02d6df2..1691a001 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -871,6 +871,18 @@ register_template(
 )
 
 
+register_template(
+    name="granite3_vision",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}\n<|assistant|>\n"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}\n"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
 register_template(
     name="index",
     format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index e1c3aa43..025a8bbc 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -22,7 +22,7 @@
 from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
 
-AUDIO_PLACEHOLDER = os.environ.get("AUDIO_PLACEHOLDER", "<audio>")