[misc] update format (#7277)

hoshi-hiyouga 2025-03-13 02:53:08 +08:00 committed by GitHub
parent 165d3ed084
commit 9ccfb97a2c
62 changed files with 384 additions and 288 deletions

View File

@@ -47,8 +47,6 @@ body:
       description: |
         Please provide entry arguments, error messages and stack traces that reproduces the problem.
         请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。
-        Remember to wrap your log messages with \`\`\`.
-        请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。
       value: |
         ```text

View File

@@ -58,6 +58,10 @@ jobs:
         run: |
           make style && make quality

+      - name: Check license
+        run: |
+          make license
+
       - name: Test with pytest
         run: |
           make test

View File

@@ -1,4 +1,4 @@
-.PHONY: build commit quality style test
+.PHONY: build commit license quality style test

 check_dirs := scripts src tests setup.py

@@ -9,6 +9,9 @@ commit:
	pre-commit install
	pre-commit run --all-files

+license:
+	python3 tests/check_license.py $(check_dirs)
+
 quality:
	ruff check $(check_dirs)
	ruff format --check $(check_dirs)
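
The new `license` target delegates to `tests/check_license.py`, whose contents are not part of this diff. Below is a minimal sketch of what such a checker could look like, assuming it only verifies that each Python file begins with an Apache-2.0 notice; everything except the script path and the `$(check_dirs)` arguments is hypothetical:

```python
# Hypothetical sketch of tests/check_license.py (the real script is not shown
# in this diff). It scans the paths passed on the command line and exits
# non-zero if any .py file lacks the expected license markers in its header.
import sys
from pathlib import Path

KEYWORDS = ("Apache License", "Copyright")  # assumed markers


def main() -> None:
    missing = []
    for arg in sys.argv[1:]:  # e.g. scripts src tests setup.py
        path = Path(arg)
        files = path.rglob("*.py") if path.is_dir() else [path]
        for file in files:
            head = file.read_text(encoding="utf-8")[:2048]
            if not all(keyword in head for keyword in KEYWORDS):
                missing.append(str(file))

    if missing:
        print("License header missing in:", *missing, sep="\n  ")
        sys.exit(1)


if __name__ == "__main__":
    main()
```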

View File

@@ -1,3 +1,18 @@
+# Copyright 2025 the LlamaFactory team.
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os

View File

@@ -1,3 +1,18 @@
+# Copyright 2025 the LlamaFactory team.
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os

View File

@@ -1,3 +1,18 @@
+# Copyright 2025 the LlamaFactory team.
+# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 import os

View File

@@ -1,3 +1,4 @@
+# Copyright 2025 the LlamaFactory team.
 # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,3 +1,4 @@
+# Copyright 2025 the LlamaFactory team.
 # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,3 +1,4 @@
+# Copyright 2025 the LlamaFactory team.
 # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -14,7 +14,6 @@
 import json
 import os
-from collections.abc import Sequence

 from openai import OpenAI
 from transformers.utils.versions import require_version

@@ -23,7 +22,7 @@ from transformers.utils.versions import require_version
 require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


-def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
+def calculate_gpa(grades: list[str], hours: list[int]) -> float:
     grade_to_score = {"A": 4, "B": 3, "C": 2}
     total_score, total_hour = 0, 0
     for grade, hour in zip(grades, hours):
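
This hunk is representative of the commit as a whole: abstract `Sequence[...]` hints from `collections.abc` are swapped for the builtin generic `list[...]` (PEP 585, available since Python 3.9), which drops an import per module. A self-contained sketch of the updated helper follows; the accumulation lines are an assumption, since the diff ends at the `for` statement:

```python
# Sketch of the updated helper using builtin generic hints; the loop body
# past the `for` line is assumed, not shown in the diff.
def calculate_gpa(grades: list[str], hours: list[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour  # assumed completion
        total_hour += hour

    return total_score / total_hour


print(calculate_gpa(["A", "A", "B", "C"], [3, 4, 3, 2]))  # 41 / 12 ~= 3.42
```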

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 import json
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal, Optional

@@ -35,7 +34,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     train_on_prompt: bool = False

-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
         r"""Pad batched data to the longest sequence in the batch."""
         chosen_features = []
         for feature in features:

View File

@@ -13,7 +13,7 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from collections.abc import AsyncGenerator, Sequence
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union

@@ -63,12 +63,12 @@ class BaseEngine(ABC):
     @abstractmethod
     async def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         r"""Get a list of responses of the chat model."""

@@ -77,12 +77,12 @@ class BaseEngine(ABC):
     @abstractmethod
     async def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""Get the response token-by-token of the chat model."""

View File

@@ -1,4 +1,4 @@
-# Copyright 2024 THUDM and the LlamaFactory team.
+# Copyright 2025 THUDM and the LlamaFactory team.
 #
 # This code is inspired by the THUDM's ChatGLM implementation.
 # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py

@@ -17,7 +17,7 @@
 import asyncio
 import os
-from collections.abc import AsyncGenerator, Generator, Sequence
+from collections.abc import AsyncGenerator, Generator
 from threading import Thread
 from typing import TYPE_CHECKING, Any, Optional

@@ -61,12 +61,12 @@ class ChatModel:
     def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         r"""Get a list of responses of the chat model."""

@@ -77,12 +77,12 @@ class ChatModel:
     async def achat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         r"""Asynchronously get a list of responses of the chat model."""

@@ -90,12 +90,12 @@ class ChatModel:
     def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""Get the response token-by-token of the chat model."""

@@ -109,12 +109,12 @@ class ChatModel:
     async def astream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""Asynchronously get the response token-by-token of the chat model."""

View File

@@ -15,7 +15,7 @@
 import asyncio
 import concurrent.futures
 import os
-from collections.abc import AsyncGenerator, Sequence
+from collections.abc import AsyncGenerator
 from threading import Thread
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union

@@ -78,12 +78,12 @@ class HuggingfaceEngine(BaseEngine):
         processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: dict[str, Any],
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         input_kwargs: Optional[dict[str, Any]] = {},
     ) -> tuple[dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}

@@ -219,12 +219,12 @@ class HuggingfaceEngine(BaseEngine):
         processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: dict[str, Any],
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         input_kwargs: Optional[dict[str, Any]] = {},
     ) -> list["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(

@@ -274,12 +274,12 @@ class HuggingfaceEngine(BaseEngine):
         processor: Optional["ProcessorMixin"],
         template: "Template",
         generating_args: dict[str, Any],
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         input_kwargs: Optional[dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(

@@ -338,12 +338,12 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         if not self.can_generate:

@@ -371,12 +371,12 @@ class HuggingfaceEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:

View File

@@ -13,7 +13,7 @@
 # limitations under the License.

 import uuid
-from collections.abc import AsyncGenerator, AsyncIterator, Sequence
+from collections.abc import AsyncGenerator, AsyncIterator
 from typing import TYPE_CHECKING, Any, Optional, Union

 from typing_extensions import override

@@ -102,12 +102,12 @@ class VllmEngine(BaseEngine):
     async def _generate(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"

@@ -202,12 +202,12 @@ class VllmEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> list["Response"]:
         final_output = None

@@ -231,12 +231,12 @@ class VllmEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""

View File

@@ -1,4 +1,4 @@
-# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
 #
 # This code is inspired by the OpenAccess AI Collective's axolotl library.
 # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py

@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, Optional

@@ -92,7 +91,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")

-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
         for feature in features:

@@ -205,7 +204,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
     compute_dtype: "torch.dtype" = torch.float32

-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         features = super().__call__(features)
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

@@ -221,7 +220,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""Data collator for pairwise data."""

-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         r"""Pad batched data to the longest sequence in the batch.

         We generate 2 * n examples where the first n examples represent chosen examples and

@@ -247,7 +246,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""Data collator for KTO data."""

-    def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         target_features = []
         kl_features = []
         kto_tags = []
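
The docstring above spells out the pairwise layout contract: for n input pairs the collator emits 2 * n rows, chosen first, so row i and row i + n form the i-th preference pair. A toy illustration of that ordering (the `{key}_input_ids` feature naming mirrors the pattern used elsewhere in this codebase; the values are made up):

```python
# Toy illustration of the pairwise batch layout: all chosen rows precede
# all rejected rows in the flattened batch.
features = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
    {"chosen_input_ids": [6], "rejected_input_ids": [7, 8, 9]},
]

ordered = [feature[f"{key}_input_ids"] for key in ("chosen", "rejected") for feature in features]
print(ordered)  # [[1, 2, 3], [6], [4, 5], [7, 8, 9]]
```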

View File

@@ -14,7 +14,6 @@
 import os
 from abc import abstractmethod
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional, Union

@@ -37,7 +36,7 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"

-    def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]:
+    def _find_medias(self, medias: Union[Any, list[Any]]) -> Optional[list[Any]]:
         r"""Optionally concatenate media path to media dir when loading from local disk."""
         if not isinstance(medias, list):
             medias = [medias] if medias is not None else []

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from enum import Enum, unique
 from typing import TYPE_CHECKING, Optional, TypedDict, Union

@@ -30,7 +29,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-SLOTS = Sequence[Union[str, set[str], dict[str, str]]]
+SLOTS = list[Union[str, set[str], dict[str, str]]]


 @unique
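
`SLOTS` is the type alias used for template slot lists, so this change keeps the alias subscriptable while dropping the `collections.abc` import. An illustrative value under the new alias; the slot semantics sketched in the comments are assumptions, not taken from this diff:

```python
# Illustrative SLOTS value; the comments describe assumed slot semantics.
from typing import Union

SLOTS = list[Union[str, set[str], dict[str, str]]]

example_slots: SLOTS = [
    {"bos_token"},    # a set presumably names a special token to resolve
    "{{content}}",    # a plain string acts as a format template
]
```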

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np

@@ -157,7 +156,7 @@ def _load_single_dataset(

 def _get_merged_dataset(
-    dataset_names: Optional[Sequence[str]],
+    dataset_names: Optional[list[str]],
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",

View File

@@ -18,7 +18,6 @@
 import inspect
 import math
 import re
-from collections.abc import Sequence
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO

@@ -83,9 +82,7 @@
     pass


-def _get_paligemma_token_type_ids(
-    imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor"
-) -> list[list[int]]:
+def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
     r"""Get paligemma token type ids for computing loss.

     It is slightly different with the original token type ids where the prompt part is 0.

@@ -120,7 +117,7 @@ def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcess
     return batch_token_type_ids


-def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
+def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
     r"""Make nested list of images."""
     batch_images = []
     for imglen in imglens:

@@ -140,9 +137,9 @@ class MMPluginMixin:
     def _validate_input(
         self,
         processor: Optional["MMProcessor"],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> None:
         r"""Validate if this model accepts the input modalities."""
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)

@@ -202,7 +199,7 @@
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

-    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]:
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]:
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:

@@ -223,7 +220,7 @@
         return results

-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]:
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:

@@ -241,7 +238,7 @@
         return results

-    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
+    def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]:
         r"""Regularizes audios to avoid error. Including reading and resampling."""
         results = []
         for audio in audios:

@@ -257,9 +254,9 @@
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: "MMProcessor",
         imglens: Optional[list[int]] = None,
     ) -> dict[str, "torch.Tensor"]:

@@ -335,10 +332,10 @@
 class BasePlugin(MMPluginMixin):
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         r"""Pre-process input messages before tokenization for VLMs."""

@@ -349,9 +346,9 @@ class BasePlugin(MMPluginMixin):
         self,
         input_ids: list[int],
         labels: Optional[list[int]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
     ) -> tuple[list[int], Optional[list[int]]]:

@@ -361,13 +358,13 @@ class BasePlugin(MMPluginMixin):
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         r"""Build batched multimodal inputs for VLMs.

@@ -392,10 +389,10 @@ class Gemma3Plugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -420,13 +417,13 @@
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)

@@ -441,10 +438,10 @@ class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -481,10 +478,10 @@ class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -523,10 +520,10 @@ class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -586,10 +583,10 @@ class MiniCPMVPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -686,9 +683,9 @@ class MiniCPMVPlugin(BasePlugin):
     @override
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: "MMProcessor",
         **kwargs,
     ) -> dict[str, "torch.Tensor"]:

@@ -757,13 +754,13 @@
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)

@@ -828,10 +825,10 @@ class MllamaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -850,13 +847,13 @@
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)

@@ -885,10 +882,10 @@ class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -912,9 +909,9 @@
         self,
         input_ids: list[int],
         labels: Optional[list[int]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
     ) -> tuple[list[int], Optional[list[int]]]:

@@ -931,13 +928,13 @@
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)

@@ -952,10 +949,10 @@ class PixtralPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -995,13 +992,13 @@
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)

@@ -1015,10 +1012,10 @@ class Qwen2AudioPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -1056,13 +1053,13 @@
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)

@@ -1090,7 +1087,7 @@ class Qwen2VLPlugin(BasePlugin):
     @override
     def _regularize_videos(
-        self, videos: Sequence["VideoInput"], **kwargs
+        self, videos: list["VideoInput"], **kwargs
     ) -> tuple[list[list["ImageObject"]], list[float]]:
         results, fps_per_video = [], []
         for video in videos:

@@ -1118,9 +1115,9 @@
     @override
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: "MMProcessor",
     ) -> dict[str, "torch.Tensor"]:
         image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)

@@ -1149,10 +1146,10 @@
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

@@ -1204,13 +1201,13 @@
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
         processor: Optional["MMProcessor"],
     ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)

@@ -1229,10 +1226,10 @@ class VideoLlavaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         processor: Optional["MMProcessor"],
     ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)

View File

@@ -14,7 +14,6 @@
 import json
 import os
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import Any, Literal, Optional

@@ -91,7 +90,7 @@ class DatasetAttr:
             self.set_attr(tag, attr["tags"])


-def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]:
+def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
     r"""Get the attributes of the datasets."""
     if dataset_names is None:
         dataset_names = []

View File

@@ -1,3 +1,17 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .feedback import FeedbackDatasetProcessor
 from .pairwise import PairwiseDatasetProcessor
 from .pretrain import PretrainDatasetProcessor

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging

@@ -31,14 +30,14 @@ logger = logging.get_logger(__name__)
 class FeedbackDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
-        kl_response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
+        kl_response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int], list[int], list[int], bool]:
         if response[0]["content"]:  # desired example
             kto_tag = True

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging

@@ -31,13 +30,13 @@ logger = logging.get_logger(__name__)
 class PairwiseDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int], list[int], list[int]]:
         chosen_messages = self.template.mm_plugin.process_messages(
             prompt + [response[0]], images, videos, audios, self.processor

View File

@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py

View File

@@ -14,7 +14,6 @@
 import bisect
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

@@ -46,7 +45,7 @@ class DatasetProcessor(ABC):
         ...


-def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
+def search_for_fit(numbers: list[int], capacity: int) -> int:
     r"""Find the index of largest number that fits into the knapsack with the given capacity."""
     index = bisect.bisect(numbers, capacity)
     return -1 if index == 0 else (index - 1)
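
`search_for_fit` relies on `numbers` being sorted in ascending order: `bisect.bisect` returns the insertion point for `capacity`, so the element just before that point is the largest value that still fits. A quick self-contained check:

```python
# search_for_fit expects a sorted list; bisect yields the insertion point,
# so index - 1 is the largest element <= capacity (or -1 if nothing fits).
import bisect


def search_for_fit(numbers: list[int], capacity: int) -> int:
    index = bisect.bisect(numbers, capacity)
    return -1 if index == 0 else (index - 1)


print(search_for_fit([2, 4, 7, 9], 8))  # 2 -> numbers[2] == 7 fits
print(search_for_fit([2, 4, 7, 9], 1))  # -1 -> nothing fits
```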

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional

@@ -33,13 +32,13 @@ logger = logging.get_logger(__name__)
 class SupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int]]:
         messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
         input_ids, labels = self.template.mm_plugin.process_token_ids(

View File

@@ -13,7 +13,6 @@
 # limitations under the License.

 from collections import defaultdict
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging

@@ -31,13 +30,13 @@ logger = logging.get_logger(__name__)
 class UnsupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[dict[str, str]],
-        response: Sequence[dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> tuple[list[int], list[int]]:
         if len(response) == 1:
             messages = prompt + response

View File

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional, Union

@@ -57,7 +56,7 @@ class Template:
     def encode_oneturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> tuple[list[int], list[int]]:

@@ -73,7 +72,7 @@ class Template:
     def encode_multiturn(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:

@@ -115,7 +114,7 @@ class Template:
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
     ) -> list[list[int]]:

@@ -316,7 +315,7 @@ class Llama2Template(Template):
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
-        messages: Sequence[dict[str, str]],
+        messages: list[dict[str, str]],
         system: str,
         tools: str,
     ) -> list[list[int]]:

@@ -391,7 +390,7 @@ def register_template(
     format_tools: Optional["Formatter"] = None,
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: Optional[Sequence[str]] = None,
+    stop_words: Optional[list[str]] = None,
     thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass

 from ..data import Role
@@ -35,7 +34,7 @@ class EvalTemplate:
         return "".join([example["question"]] + candidates + [self.answer]), example["answer"]

     def format_example(
-        self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str
+        self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str
     ) -> list[dict[str, str]]:
         r"""Convert dataset examples to messages."""
         messages = []


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py


@@ -1,4 +1,4 @@
-# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's PEFT library.
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
@@ -17,7 +17,6 @@
 import gc
 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any, Literal, Union

 import torch
@@ -98,7 +97,7 @@ def check_dependencies() -> None:
     logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


-def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
+def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float:
     r"""Calculate effective tokens per second."""
     effective_token_num = 0
     for data in dataset:
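For context on the `calculate_tps` signature change above: the helper divides the number of loss-contributing tokens by the training runtime. A rough sketch of that calculation, assuming SFT examples carry `input_ids` and reward-model examples carry chosen/rejected pairs (the field names and exact formula are assumptions, not taken from this diff):

```python
from typing import Any, Literal


def tokens_per_second_sketch(
    dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]
) -> float:
    # Sum the tokens that actually pass through the model during training.
    effective_token_num = 0
    for data in dataset:
        if stage == "sft":
            effective_token_num += len(data["input_ids"])
        elif stage == "rm":  # both candidate answers are encoded for ranking (assumed keys)
            effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])

    # "epoch" and "train_runtime" are the usual HF Trainer train() metrics.
    return effective_token_num * metrics["epoch"] / metrics["train_runtime"]
```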


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py


@@ -1,3 +1,17 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import json
 from dataclasses import dataclass, field
 from typing import Literal, Optional, Union


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and PEFT library,
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py


@@ -1,4 +1,4 @@
-# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
+# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
 #
 # This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from typing import TYPE_CHECKING

 import torch
@@ -27,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None:
+def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None:
     check_version("deepspeed>=0.13.0")
     from deepspeed.utils import set_z3_leaf_modules  # type: ignore


@@ -1,4 +1,4 @@
-# Copyright 2024 Musab Gultekin and the LlamaFactory team.
+# Copyright 2025 Musab Gultekin and the LlamaFactory team.
 #
 # This code is based on the Musab Gultekin's functionary library.
 # https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's Transformers and Optimum library.
 # https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py


@@ -1,4 +1,4 @@
-# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2025 LMSYS and the LlamaFactory team.
 # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
 #
 # This code is inspired by the LMSYS's FastChat library.


@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional
@@ -180,7 +179,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni

 def patch_target_modules(
-    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
 ) -> list[str]:
     r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
@@ -37,11 +37,11 @@ if is_jieba_available():
 if is_nltk_available():
-    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore

 if is_rouge_available():
-    from rouge_chinese import Rouge
+    from rouge_chinese import Rouge  # type: ignore


 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
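The two `# type: ignore` comments added above silence static type checkers for `nltk` and `rouge_chinese`, optional dependencies that may be absent from the environment being analyzed. The surrounding guard pattern looks roughly like this (the `find_spec` body is an assumption; the project ships its own availability helpers):

```python
import importlib.util


def is_nltk_available() -> bool:
    # True only if the optional package can be imported in this environment.
    return importlib.util.find_spec("nltk") is not None


if is_nltk_available():
    # The ignore comment stops checkers from erroring when nltk isn't installed.
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore
```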


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py


@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional, Union

 import torch
@@ -33,7 +32,7 @@ if TYPE_CHECKING:
     from ..data.data_utils import DatasetModule


-def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: list[str] = []) -> None:
     state_dict_a = model_a.state_dict()
     state_dict_b = model_b.state_dict()
     assert set(state_dict_a.keys()) == set(state_dict_b.keys())


@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
 # and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus

tests/check_license.py (new file, 38 lines)

@@ -0,0 +1,38 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from pathlib import Path
+
+
+KEYWORDS = ("Copyright", "2025", "LlamaFactory")
+
+
+def main():
+    path_list = []
+    for check_dir in sys.argv[1:]:
+        path_list.extend(Path(check_dir).glob("**/*.py"))
+
+    for path in path_list:
+        with open(path.absolute(), encoding="utf-8") as f:
+            file_content = f.read().strip().split("\n")
+        if not file_content[0]:
+            continue
+
+        print(f"Check license: {path}")
+        assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."
+
+
+if __name__ == "__main__":
+    main()
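The new checker asserts that the first non-empty line of every `.py` file under the given directories mentions all three keywords, and skips files that are empty after stripping. A small self-contained illustration of the same check (the file contents here are made up):

```python
KEYWORDS = ("Copyright", "2025", "LlamaFactory")

samples = {
    "good.py": '# Copyright 2025 the LlamaFactory team.\nimport os\n',
    "bad.py": 'import os\n# Copyright 2025 the LlamaFactory team.\n',
}

for name, content in samples.items():
    first_line = content.strip().split("\n")[0]
    ok = all(keyword in first_line for keyword in KEYWORDS)
    print(name, "passes" if ok else "FAILS")  # good.py passes, bad.py FAILS
```

It is run with the directories to scan as arguments, e.g. `python3 tests/check_license.py src tests`; an `AssertionError` on any file fails the run.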


@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any

 import pytest
@@ -97,7 +96,7 @@ def _check_plugin(
     plugin: "BasePlugin",
     tokenizer: "PreTrainedTokenizer",
     processor: "ProcessorMixin",
-    expected_mm_messages: Sequence[dict[str, str]] = MM_MESSAGES,
+    expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
     expected_input_ids: list[int] = INPUT_IDS,
     expected_labels: list[int] = LABELS,
     expected_mm_inputs: dict[str, Any] = {},


@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING

 import pytest
@@ -41,7 +40,7 @@ MESSAGES = [

 def _check_tokenization(
-    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
 ) -> None:
     r"""Check token ids and texts.