mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 09:52:14 +08:00

	[misc] update format (#7277)
This commit is contained in:
parent 4b9d8da5a4
commit 650a9a9057

.github/ISSUE_TEMPLATE/1-bug-report.yml (vendored, 2 changes)
@@ -47,8 +47,6 @@ body:
       description: |
         Please provide entry arguments, error messages and stack traces that reproduces the problem.
         请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
-        Remember to wrap your log messages with ```.
-        请务必使用 Markdown 标签 ``` 来包裹您的日志信息。
 
       value: |
         ```text

.github/workflows/tests.yml (vendored, 4 changes)
@@ -58,6 +58,10 @@ jobs:
         run: |
           make style && make quality
 
+      - name: Check license
+        run: |
+          make license
+
       - name: Test with pytest
         run: |
           make test

Makefile (5 changes)
@@ -1,4 +1,4 @@
-.PHONY: build commit quality style test
+.PHONY: build commit license quality style test
 
 check_dirs := scripts src tests setup.py
 
@@ -9,6 +9,9 @@ commit:
 	pre-commit install
 	pre-commit run --all-files
 
+license:
+	python3 tests/check_license.py $(check_dirs)
+
 quality:
 	ruff check $(check_dirs)
 	ruff format --check $(check_dirs)
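The new `license` target delegates to `tests/check_license.py`, which is not itself part of this diff. A minimal sketch of what such a checker might look like, assuming it simply verifies that every Python file under `check_dirs` carries an Apache-2.0 header (all names and behavior below are assumptions, not the project's actual script):

```python
# Hypothetical sketch of a license-header checker; the real
# tests/check_license.py is not shown in this diff.
import sys
from pathlib import Path

KEYWORDS = ("Apache License", "http://www.apache.org/licenses/LICENSE-2.0")


def has_license_header(path: Path) -> bool:
    # License headers sit at the top of a file, so a small prefix suffices.
    head = path.read_text(encoding="utf-8", errors="ignore")[:2048]
    return all(keyword in head for keyword in KEYWORDS)


def main(targets: list[str]) -> None:
    failed = []
    for target in map(Path, targets):
        # check_dirs mixes directories (scripts, src, tests) and a file (setup.py).
        paths = [target] if target.is_file() else sorted(target.rglob("*.py"))
        failed.extend(str(p) for p in paths if not has_license_header(p))

    if failed:
        print("Missing license header:", *failed, sep="\n  ")
        sys.exit(1)


if __name__ == "__main__":
    main(sys.argv[1:])  # e.g. python3 tests/check_license.py scripts src tests setup.py
```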
@@ -1,3 +1,18 @@
+# Copyright 2025 the LlamaFactory team.
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os
 

@@ -1,3 +1,18 @@
+# Copyright 2025 the LlamaFactory team.
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os
 

@@ -1,3 +1,18 @@
+# Copyright 2025 the LlamaFactory team.
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os
 
@@ -1,3 +1,4 @@
+# Copyright 2025 the LlamaFactory team.
 # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -1,3 +1,4 @@
+# Copyright 2025 the LlamaFactory team.
 # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

@@ -1,3 +1,4 @@
+# Copyright 2025 the LlamaFactory team.
 # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +14,6 @@
 
 import json
 import os
-from collections.abc import Sequence
 
 from openai import OpenAI
 from transformers.utils.versions import require_version
@@ -23,7 +22,7 @@ from transformers.utils.versions import require_version
 require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
 
 
-def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
+def calculate_gpa(grades: list[str], hours: list[int]) -> float:
     grade_to_score = {"A": 4, "B": 3, "C": 2}
     total_score, total_hour = 0, 0
     for grade, hour in zip(grades, hours):
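The same `Sequence` to `list` swap is applied to dozens of signatures below. A minimal illustration of why it works, completing the `calculate_gpa` body in the obvious way (the remainder is not shown in the hunk above): builtin generics like `list[str]` are subscriptable in annotations since Python 3.9 (PEP 585), so no `collections.abc` import is needed, at the cost of being stricter than `Sequence`, which also admitted tuples.

```python
# Since Python 3.9 (PEP 585), builtin containers are usable as generic
# annotations, so list[str] works without importing Sequence. The loop
# body here completes the snippet shown in the hunk in the obvious way.
def calculate_gpa(grades: list[str], hours: list[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour
        total_hour += hour
    return total_score / total_hour


print(calculate_gpa(["A", "B"], [3, 1]))  # (4*3 + 3*1) / 4 = 3.75
```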
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import json
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal, Optional
@@ -35,7 +34,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 
     train_on_prompt: bool = False
 
-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
         r"""Pad batched data to the longest sequence in the batch."""
         chosen_features = []
         for feature in features:
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator, Sequence
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union
@@ -63,12 +63,12 @@ class BaseEngine(ABC):
     @abstractmethod
     async def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         r"""Get a list of responses of the chat model."""
@@ -77,12 +77,12 @@ class BaseEngine(ABC):
     @abstractmethod
     async def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""Get the response token-by-token of the chat model."""
@@ -1,4 +1,4 @@
-# Copyright 2024 THUDM and the LlamaFactory team.
+# Copyright 2025 THUDM and the LlamaFactory team.
 #
 # This code is inspired by the THUDM's ChatGLM implementation.
 # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
@@ -17,7 +17,7 @@
 
 import asyncio
 import os
-from collections.abc import AsyncGenerator, Generator, Sequence
+from collections.abc import AsyncGenerator, Generator
 from threading import Thread
 from typing import TYPE_CHECKING, Any, Optional
@@ -61,12 +61,12 @@ class ChatModel:
 
     def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         r"""Get a list of responses of the chat model."""
@@ -77,12 +77,12 @@ class ChatModel:
 
     async def achat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         r"""Asynchronously get a list of responses of the chat model."""
@@ -90,12 +90,12 @@ class ChatModel:
 
     def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""Get the response token-by-token of the chat model."""
@@ -109,12 +109,12 @@ class ChatModel:
 
     async def astream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""Asynchronously get the response token-by-token of the chat model."""
@@ -15,7 +15,7 @@
 import asyncio
 import concurrent.futures
 import os
-from collections.abc import AsyncGenerator, Sequence
+from collections.abc import AsyncGenerator
 from threading import Thread
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
@@ -78,12 +78,12 @@ class HuggingfaceEngine(BaseEngine):
         processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: dict[str, Any],
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         input_kwargs: Optional[dict[str, Any]] = {},
     ) -> tuple[dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
@@ -219,12 +219,12 @@ class HuggingfaceEngine(BaseEngine):
         processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: dict[str, Any],
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         input_kwargs: Optional[dict[str, Any]] = {},
     ) -> list["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
@@ -274,12 +274,12 @@ class HuggingfaceEngine(BaseEngine):
         processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: dict[str, Any],
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         input_kwargs: Optional[dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
@@ -338,12 +338,12 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         if not self.can_generate:
@@ -371,12 +371,12 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import uuid
-from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from typing_extensions import override
@@ -102,12 +102,12 @@ class VllmEngine(BaseEngine):
 
     async def _generate(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
@@ -202,12 +202,12 @@ class VllmEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         final_output = None
@@ -231,12 +231,12 @@ class VllmEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
@@ -1,4 +1,4 @@
-# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
 #
 # This code is inspired by the OpenAccess AI Collective's axolotl library.
 # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -92,7 +91,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")
 
-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
         for feature in features:
@@ -205,7 +204,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
     compute_dtype: "torch.dtype" = torch.float32
 
-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         features = super().__call__(features)
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -221,7 +220,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""Data collator for pairwise data."""
 
-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         r"""Pad batched data to the longest sequence in the batch.
 
         We generate 2 * n examples where the first n examples represent chosen examples and
@@ -247,7 +246,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""Data collator for KTO data."""
 
-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         target_features = []
         kl_features = []
         kto_tags = []
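The pairwise collator's docstring above describes a 2 * n layout: each paired feature is split so that the first n examples in the flattened batch are the chosen responses. A toy sketch of that layout (the field names here are hypothetical stand-ins for the collator's real tokenized features, which this diff does not show):

```python
# Toy sketch of the 2 * n pairwise layout the docstring describes; the
# "chosen"/"rejected" keys are illustrative, not the collator's real schema.
def flatten_pairwise(features: list[dict]) -> list[dict]:
    chosen = [feat["chosen"] for feat in features]
    rejected = [feat["rejected"] for feat in features]
    return chosen + rejected  # first n examples are chosen, last n rejected


batch = flatten_pairwise(
    [
        {"chosen": {"input_ids": [1, 2]}, "rejected": {"input_ids": [3]}},
        {"chosen": {"input_ids": [4]}, "rejected": {"input_ids": [5, 6]}},
    ]
)
assert len(batch) == 4  # 2 * n with n = 2
```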
@@ -14,7 +14,6 @@
 
 import os
 from abc import abstractmethod
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -37,7 +36,7 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"
 
-    def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
+    def _find_medias(self, medias: Union[Any, list[Any]]) -> Optional[list[Any]]:
         r"""Optionally concatenate media path to media dir when loading from local disk."""
         if not isinstance(medias, list):
             medias = [medias] if medias is not None else []
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections.abc import Sequence
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Optional, TypedDict, Union
@@ -30,7 +29,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)
 
 
-SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
+SLOTS = list[Union[str, set[str], dict[str, str]]]
 
 
 @unique
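`SLOTS` is a type alias for template slot values; after the change it reads as a plain list whose items may be literal strings, sets of special-token names, or dicts. An illustrative value that satisfies the alias (the meaning of each item kind is assumed from the alias shape, not confirmed by this diff):

```python
from typing import Union

# The updated alias from the hunk above.
SLOTS = list[Union[str, set[str], dict[str, str]]]

# Illustrative only: one hypothetical value matching the alias.
example_slots: SLOTS = [
    "{{content}}",            # a literal string slot
    {"bos_token"},            # a set of special-token names
    {"token": "<|eot_id|>"},  # a dict describing a token
]
```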
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
@@ -157,7 +156,7 @@ def _load_single_dataset(
 
 
 def _get_merged_dataset(
-    dataset_names: Optional[Sequence[str]],
+    dataset_names: Optional[list[str]],
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
@@ -18,7 +18,6 @@
 import inspect
 import math
 import re
-from collections.abc import Sequence
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
@@ -83,9 +82,7 @@ if TYPE_CHECKING:
         pass
 
 
-def _get_paligemma_token_type_ids(
-    imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
-) -> list[list[int]]:
+def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
     r"""Get paligemma token type ids for computing loss.
 
     It is slightly different with the original token type ids where the prompt part is 0.
@@ -120,7 +117,7 @@ def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcess
     return batch_token_type_ids
 
 
-def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
+def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
     r"""Make nested list of images."""
     batch_images = []
     for imglen in imglens:
@@ -140,9 +137,9 @@ class MMPluginMixin:
     def _validate_input(
         self,
         processor: Optional["MMProcessor"],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> None:
         r"""Validate if this model accepts the input modalities."""
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
@@ -202,7 +199,7 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
-    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]:
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
@@ -223,7 +220,7 @@ class MMPluginMixin:
 
         return results
 
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:
@@ -241,7 +238,7 @@ class MMPluginMixin:
 
         return results
 
-    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
+    def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
         r"""Regularizes audios to avoid error. Including reading and resampling."""
         results = []
         for audio in audios:
@@ -257,9 +254,9 @@ class MMPluginMixin:
 
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: "MMProcessor",
         imglens: Optional[list[int]] = None,
     ) -> dict[str, "torch.Tensor"]:
@@ -335,10 +332,10 @@ class MMPluginMixin:
 class BasePlugin(MMPluginMixin):
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         r"""Pre-process input messages before tokenization for VLMs."""
@@ -349,9 +346,9 @@ class BasePlugin(MMPluginMixin):
         self,
         input_ids: list[int],
         labels: Optional[list[int]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
     ) -> tuple[list[int], Optional[list[int]]]:
@@ -361,13 +358,13 @@ class BasePlugin(MMPluginMixin):
 
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         r"""Build batched multimodal inputs for VLMs.
@@ -392,10 +389,10 @@ class Gemma3Plugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -420,13 +417,13 @@ class Gemma3Plugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
@@ -441,10 +438,10 @@ class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -481,10 +478,10 @@ class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -523,10 +520,10 @@ class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -586,10 +583,10 @@ class MiniCPMVPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -686,9 +683,9 @@ class MiniCPMVPlugin(BasePlugin):
     @override
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: "MMProcessor",
         **kwargs,
     ) -> dict[str, "torch.Tensor"]:
@@ -757,13 +754,13 @@ class MiniCPMVPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
@@ -828,10 +825,10 @@ class MllamaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -850,13 +847,13 @@ class MllamaPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
@@ -885,10 +882,10 @@ class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -912,9 +909,9 @@ class PaliGemmaPlugin(BasePlugin):
         self,
         input_ids: list[int],
         labels: Optional[list[int]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
     ) -> tuple[list[int], Optional[list[int]]]:
@@ -931,13 +928,13 @@ class PaliGemmaPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
@@ -952,10 +949,10 @@ class PixtralPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -995,13 +992,13 @@ class PixtralPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
@@ -1015,10 +1012,10 @@ class Qwen2AudioPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -1056,13 +1053,13 @@ class Qwen2AudioPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
@@ -1090,7 +1087,7 @@ class Qwen2VLPlugin(BasePlugin):
 
     @override
     def _regularize_videos(
-        self, videos: Sequence["VideoInput"], **kwargs
+        self, videos: list["VideoInput"], **kwargs
     ) -> tuple[list[list["ImageObject"]], list[float]]:
         results, fps_per_video = [], []
         for video in videos:
@@ -1118,9 +1115,9 @@ class Qwen2VLPlugin(BasePlugin):
     @override
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: "MMProcessor",
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
@@ -1149,10 +1146,10 @@ class Qwen2VLPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -1204,13 +1201,13 @@ class Qwen2VLPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
@@ -1229,10 +1226,10 @@ class VideoLlavaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
@@ -14,7 +14,6 @@
 
 import json
 import os
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal, Optional
@@ -91,7 +90,7 @@ class DatasetAttr:
                 self.set_attr(tag, attr["tags"])
 
 
-def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
+def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
     r"""Get the attributes of the datasets."""
     if dataset_names is None:
         dataset_names = []
@@ -1,3 +1,17 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .feedback import FeedbackDatasetProcessor
 from .pairwise import PairwiseDatasetProcessor
 from .pretrain import PretrainDatasetProcessor
@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
@@ -31,14 +30,14 @@ logger = logging.get_logger(__name__)
 class FeedbackDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
-        kl_response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
+        kl_response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int], list[int], list[int], bool]:
         if response[0]["content"]:  # desired example
             kto_tag = True

@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
@@ -31,13 +30,13 @@ logger = logging.get_logger(__name__)
 class PairwiseDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int], list[int], list[int]]:
         chosen_messages = self.template.mm_plugin.process_messages(
             prompt + [response[0]], images, videos, audios, self.processor

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py

@@ -14,7 +14,6 @@

 import bisect
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

@@ -46,7 +45,7 @@ class DatasetProcessor(ABC):
         ...


-def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+def search_for_fit(numbers: list[int], capacity: int) -> int:
     r"""Find the index of largest number that fits into the knapsack with the given capacity."""
     index = bisect.bisect(numbers, capacity)
     return -1 if index == 0 else (index - 1)

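In isolation, the `search_for_fit` helper shown in this hunk can be exercised as below. It assumes `numbers` is sorted in ascending order, since it delegates to `bisect`; the sample inputs are illustrative and not taken from this commit:

```python
import bisect


def search_for_fit(numbers: list[int], capacity: int) -> int:
    r"""Find the index of largest number that fits into the knapsack with the given capacity."""
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


# Illustrative inputs: sorted sample lengths and a remaining packing capacity.
print(search_for_fit([2, 4, 7, 9], capacity=8))  # 2, because numbers[2] == 7 still fits
print(search_for_fit([10, 12], capacity=8))      # -1, nothing fits
```
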
@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

@@ -33,13 +32,13 @@ logger = logging.get_logger(__name__)
 class SupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int]]:
         messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
         input_ids, labels = self.template.mm_plugin.process_token_ids(

@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
@@ -31,13 +30,13 @@ logger = logging.get_logger(__name__)
 class UnsupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int]]:
         if len(response) == 1:
             messages = prompt + response

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

@@ -57,7 +56,7 @@ class Template:
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:
@@ -73,7 +72,7 @@ class Template:
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
@@ -115,7 +114,7 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
     ) -> list[list[int]]:
@@ -316,7 +315,7 @@ class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: str,
         tools: str,
     ) -> list[list[int]]:
@@ -391,7 +390,7 @@ def register_template(
     format_tools: Optional["Formatter"] = None,
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: Optional[Sequence[str]] = None,
+    stop_words: Optional[list[str]] = None,
     thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass

 from ..data import Role
@@ -35,7 +34,7 @@ class EvalTemplate:
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

     def format_example(
-        self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str
+        self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
     ) -> list[dict[str, str]]:
         r"""Convert dataset examples to messages."""
         messages = []

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py

@@ -1,4 +1,4 @@
-# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's PEFT library.
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
@@ -17,7 +17,6 @@

 import gc
 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
@@ -98,7 +97,7 @@ def check_dependencies() -> None:
         logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


-def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
+def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
     r"""Calculate effective tokens per second."""
     effective_token_num = 0
     for data in dataset:

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py

@@ -1,3 +1,17 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 from dataclasses import dataclass, field
 from typing import Literal, Optional, Union

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py

@@ -1,4 +1,4 @@
-# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
 #
 # This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from typing import TYPE_CHECKING

 import torch
@@ -27,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
     check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore

@@ -1,4 +1,4 @@
-# Copyright 2024 Musab Gultekin and the LlamaFactory team.
+# Copyright 2025 Musab Gultekin and the LlamaFactory team.
 #
 # This code is based on the Musab Gultekin's functionary library.
 # https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and Optimum library.
 # https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py

@@ -1,4 +1,4 @@
-# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2025 LMSYS and the LlamaFactory team.
 # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
 #
 # This code is inspired by the LMSYS's FastChat library.

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional

@@ -180,7 +179,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni


 def patch_target_modules(
-    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
 ) -> list[str]:
     r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
@@ -37,11 +37,11 @@ if is_jieba_available():


 if is_nltk_available():
-    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore


 if is_rouge_available():
-    from rouge_chinese import Rouge
+    from rouge_chinese import Rouge  # type: ignore


 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional, Union

 import torch
@@ -33,7 +32,7 @@ if TYPE_CHECKING:
     from ..data.data_utils import DatasetModule


-def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: list[str] = []) -> None:
     state_dict_a = model_a.state_dict()
     state_dict_b = model_b.state_dict()
     assert set(state_dict_a.keys()) == set(state_dict_b.keys())

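For reference, a minimal sketch of how this comparison helper can be used. Only the key check is visible in this hunk; the tensor-level loop below is an assumed continuation, and the tolerance values are illustrative:

```python
import torch


def compare_model(model_a: torch.nn.Module, model_b: torch.nn.Module, diff_keys: list[str] = []) -> None:
    # Visible in the hunk: both models must expose exactly the same parameter names.
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()
    assert set(state_dict_a.keys()) == set(state_dict_b.keys())
    # Assumed continuation: parameters named in diff_keys should differ, the rest should match.
    for name in state_dict_a.keys():
        if any(key in name for key in diff_keys):
            assert not torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)
        else:
            assert torch.allclose(state_dict_a[name], state_dict_b[name], rtol=1e-4, atol=1e-5)


# Usage sketch: two linear layers with identical weights pass the default check.
model_a = torch.nn.Linear(4, 4)
model_b = torch.nn.Linear(4, 4)
model_b.load_state_dict(model_a.state_dict())
compare_model(model_a, model_b)
```
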
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
 # and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus

38  tests/check_license.py  Normal file

@@ -0,0 +1,38 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from pathlib import Path
+
+
+KEYWORDS = ("Copyright", "2025", "LlamaFactory")
+
+
+def main():
+    path_list = []
+    for check_dir in sys.argv[1:]:
+        path_list.extend(Path(check_dir).glob("**/*.py"))
+
+    for path in path_list:
+        with open(path.absolute(), encoding="utf-8") as f:
+            file_content = f.read().strip().split("\n")
+            if not file_content[0]:
+                continue
+
+            print(f"Check license: {path}")
+            assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."
+
+
+if __name__ == "__main__":
+    main()
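A quick sanity check of the rule the new script enforces: it walks every `*.py` file under the directories passed on the command line and asserts that the file's first line contains all three keywords. The headers below are illustrative, not taken from the repo:

```python
KEYWORDS = ("Copyright", "2025", "LlamaFactory")

good_header = "# Copyright 2025 the LlamaFactory team."
stale_header = "# Copyright 2024 the LlamaFactory team."  # stale year, missing "2025"

assert all(keyword in good_header for keyword in KEYWORDS)
assert not all(keyword in stale_header for keyword in KEYWORDS)
```
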
@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any

 import pytest
@@ -97,7 +96,7 @@ def _check_plugin(
     plugin: "BasePlugin",
     tokenizer: "PreTrainedTokenizer",
     processor: "ProcessorMixin",
-    expected_mm_messages: Sequence[dict[str, str]] = MM_MESSAGES,
+    expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
     expected_input_ids: list[int] = INPUT_IDS,
     expected_labels: list[int] = LABELS,
     expected_mm_inputs: dict[str, Any] = {},

@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING

 import pytest
@@ -41,7 +40,7 @@ MESSAGES = [


 def _check_tokenization(
-    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
 ) -> None:
     r"""Check token ids and texts.