From 9ccfb97a2c06cedf54ae0d82db62a9b068be9637 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Thu, 13 Mar 2025 02:53:08 +0800 Subject: [PATCH] [misc] update format (#7277) --- .github/ISSUE_TEMPLATE/1-bug-report.yml | 2 - .github/workflows/tests.yml | 4 + Makefile | 5 +- data/belle_multiturn/belle_multiturn.py | 15 + data/hh_rlhf_en/hh_rlhf_en.py | 15 + data/ultra_chat/ultra_chat.py | 15 + evaluation/ceval/ceval.py | 1 + evaluation/cmmlu/cmmlu.py | 1 + evaluation/mmlu/mmlu.py | 1 + scripts/api_example/test_toolcall.py | 3 +- scripts/stat_utils/cal_ppl.py | 3 +- src/llamafactory/chat/base_engine.py | 18 +- src/llamafactory/chat/chat_model.py | 36 +-- src/llamafactory/chat/hf_engine.py | 42 +-- src/llamafactory/chat/vllm_engine.py | 26 +- src/llamafactory/data/collator.py | 11 +- src/llamafactory/data/converter.py | 3 +- src/llamafactory/data/data_utils.py | 3 +- src/llamafactory/data/loader.py | 3 +- src/llamafactory/data/mm_plugin.py | 259 +++++++++--------- src/llamafactory/data/parser.py | 3 +- src/llamafactory/data/processor/__init__.py | 14 + src/llamafactory/data/processor/feedback.py | 13 +- src/llamafactory/data/processor/pairwise.py | 11 +- src/llamafactory/data/processor/pretrain.py | 2 +- .../data/processor/processor_utils.py | 3 +- src/llamafactory/data/processor/supervised.py | 11 +- .../data/processor/unsupervised.py | 11 +- src/llamafactory/data/template.py | 11 +- src/llamafactory/eval/template.py | 3 +- src/llamafactory/extras/env.py | 2 +- src/llamafactory/extras/logging.py | 2 +- src/llamafactory/extras/misc.py | 5 +- src/llamafactory/extras/packages.py | 2 +- src/llamafactory/hparams/data_args.py | 2 +- src/llamafactory/hparams/model_args.py | 2 +- src/llamafactory/hparams/parser.py | 2 +- src/llamafactory/hparams/training_args.py | 14 + .../model/model_utils/checkpointing.py | 2 +- .../model/model_utils/longlora.py | 2 +- src/llamafactory/model/model_utils/moe.py | 3 +- src/llamafactory/model/model_utils/packing.py | 2 +- .../model/model_utils/quantization.py | 2 +- src/llamafactory/model/model_utils/rope.py | 2 +- src/llamafactory/model/model_utils/visual.py | 3 +- src/llamafactory/train/dpo/trainer.py | 2 +- src/llamafactory/train/dpo/workflow.py | 2 +- src/llamafactory/train/kto/trainer.py | 2 +- src/llamafactory/train/kto/workflow.py | 2 +- src/llamafactory/train/ppo/trainer.py | 2 +- src/llamafactory/train/ppo/workflow.py | 2 +- src/llamafactory/train/pt/workflow.py | 2 +- src/llamafactory/train/rm/trainer.py | 2 +- src/llamafactory/train/rm/workflow.py | 2 +- src/llamafactory/train/sft/metric.py | 6 +- src/llamafactory/train/sft/trainer.py | 2 +- src/llamafactory/train/sft/workflow.py | 2 +- src/llamafactory/train/test_utils.py | 3 +- src/llamafactory/train/trainer_utils.py | 2 +- tests/check_license.py | 38 +++ tests/data/test_mm_plugin.py | 3 +- tests/data/test_template.py | 3 +- 62 files changed, 384 insertions(+), 288 deletions(-) create mode 100644 tests/check_license.py diff --git a/.github/ISSUE_TEMPLATE/1-bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml index 7ffb427c..4645bac9 100644 --- a/.github/ISSUE_TEMPLATE/1-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -47,8 +47,6 @@ body: description: | Please provide entry arguments, error messages and stack traces that reproduces the problem. 请提供入口参数,错误日志以及异常堆栈以便于我们复现问题。 - Remember to wrap your log messages with \`\`\`. 
- 请务必使用 Markdown 标签 \`\`\` 来包裹您的日志信息。 value: | ```text diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 88dcf8f2..f7219141 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -58,6 +58,10 @@ jobs: run: | make style && make quality + - name: Check license + run: | + make license + - name: Test with pytest run: | make test diff --git a/Makefile b/Makefile index 030b39b1..7eb00286 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build commit quality style test +.PHONY: build commit license quality style test check_dirs := scripts src tests setup.py @@ -9,6 +9,9 @@ commit: pre-commit install pre-commit run --all-files +license: + python3 tests/check_license.py $(check_dirs) + quality: ruff check $(check_dirs) ruff format --check $(check_dirs) diff --git a/data/belle_multiturn/belle_multiturn.py b/data/belle_multiturn/belle_multiturn.py index 2267c7ce..2c2ed4da 100644 --- a/data/belle_multiturn/belle_multiturn.py +++ b/data/belle_multiturn/belle_multiturn.py @@ -1,3 +1,18 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os diff --git a/data/hh_rlhf_en/hh_rlhf_en.py b/data/hh_rlhf_en/hh_rlhf_en.py index 083130f1..287eac40 100644 --- a/data/hh_rlhf_en/hh_rlhf_en.py +++ b/data/hh_rlhf_en/hh_rlhf_en.py @@ -1,3 +1,18 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import os diff --git a/data/ultra_chat/ultra_chat.py b/data/ultra_chat/ultra_chat.py index 9eafa2ef..2ce17204 100644 --- a/data/ultra_chat/ultra_chat.py +++ b/data/ultra_chat/ultra_chat.py @@ -1,3 +1,18 @@ +# Copyright 2025 the LlamaFactory team. +# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + import json import os diff --git a/evaluation/ceval/ceval.py b/evaluation/ceval/ceval.py index e18be8ee..72693ebf 100644 --- a/evaluation/ceval/ceval.py +++ b/evaluation/ceval/ceval.py @@ -1,3 +1,4 @@ +# Copyright 2025 the LlamaFactory team. # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/evaluation/cmmlu/cmmlu.py b/evaluation/cmmlu/cmmlu.py index 517d63f8..44c52f1b 100644 --- a/evaluation/cmmlu/cmmlu.py +++ b/evaluation/cmmlu/cmmlu.py @@ -1,3 +1,4 @@ +# Copyright 2025 the LlamaFactory team. # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/evaluation/mmlu/mmlu.py b/evaluation/mmlu/mmlu.py index 63547757..63127426 100644 --- a/evaluation/mmlu/mmlu.py +++ b/evaluation/mmlu/mmlu.py @@ -1,3 +1,4 @@ +# Copyright 2025 the LlamaFactory team. # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/scripts/api_example/test_toolcall.py b/scripts/api_example/test_toolcall.py index 2dff4aab..8e933914 100644 --- a/scripts/api_example/test_toolcall.py +++ b/scripts/api_example/test_toolcall.py @@ -14,7 +14,6 @@ import json import os -from collections.abc import Sequence from openai import OpenAI from transformers.utils.versions import require_version @@ -23,7 +22,7 @@ from transformers.utils.versions import require_version require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") -def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float: +def calculate_gpa(grades: list[str], hours: list[int]) -> float: grade_to_score = {"A": 4, "B": 3, "C": 2} total_score, total_hour = 0, 0 for grade, hour in zip(grades, hours): diff --git a/scripts/stat_utils/cal_ppl.py b/scripts/stat_utils/cal_ppl.py index a318ee46..8d47ffd8 100644 --- a/scripts/stat_utils/cal_ppl.py +++ b/scripts/stat_utils/cal_ppl.py @@ -13,7 +13,6 @@ # limitations under the License. import json -from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Literal, Optional @@ -35,7 +34,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): train_on_prompt: bool = False - def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, torch.Tensor]: + def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]: r"""Pad batched data to the longest sequence in the batch.""" chosen_features = [] for feature in features: diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py index 3b9bf5f4..6d497c1a 100644 --- a/src/llamafactory/chat/base_engine.py +++ b/src/llamafactory/chat/base_engine.py @@ -13,7 +13,7 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Optional, Union @@ -63,12 +63,12 @@ class BaseEngine(ABC): @abstractmethod async def chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> list["Response"]: r"""Get a list of responses of the chat model.""" @@ -77,12 +77,12 @@ class BaseEngine(ABC): @abstractmethod async def stream_chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: r"""Get the response token-by-token of the chat model.""" diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index 63651184..8a604619 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -1,4 +1,4 @@ -# Copyright 2024 THUDM and the LlamaFactory team. +# Copyright 2025 THUDM and the LlamaFactory team. # # This code is inspired by the THUDM's ChatGLM implementation. 
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py @@ -17,7 +17,7 @@ import asyncio import os -from collections.abc import AsyncGenerator, Generator, Sequence +from collections.abc import AsyncGenerator, Generator from threading import Thread from typing import TYPE_CHECKING, Any, Optional @@ -61,12 +61,12 @@ class ChatModel: def chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> list["Response"]: r"""Get a list of responses of the chat model.""" @@ -77,12 +77,12 @@ class ChatModel: async def achat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> list["Response"]: r"""Asynchronously get a list of responses of the chat model.""" @@ -90,12 +90,12 @@ class ChatModel: def stream_chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> Generator[str, None, None]: r"""Get the response token-by-token of the chat model.""" @@ -109,12 +109,12 @@ class ChatModel: async def astream_chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: r"""Asynchronously get the response token-by-token of the chat model.""" diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 510cf0c6..813c0976 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -15,7 +15,7 @@ import asyncio import concurrent.futures import os -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator from threading import Thread from typing import TYPE_CHECKING, Any, Callable, Optional, Union @@ -78,12 +78,12 @@ class HuggingfaceEngine(BaseEngine): processor: Optional["ProcessorMixin"], template: "Template", generating_args: dict[str, Any], - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = 
None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, input_kwargs: Optional[dict[str, Any]] = {}, ) -> tuple[dict[str, Any], int]: mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]} @@ -219,12 +219,12 @@ class HuggingfaceEngine(BaseEngine): processor: Optional["ProcessorMixin"], template: "Template", generating_args: dict[str, Any], - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, input_kwargs: Optional[dict[str, Any]] = {}, ) -> list["Response"]: gen_kwargs, prompt_length = HuggingfaceEngine._process_args( @@ -274,12 +274,12 @@ class HuggingfaceEngine(BaseEngine): processor: Optional["ProcessorMixin"], template: "Template", generating_args: dict[str, Any], - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, input_kwargs: Optional[dict[str, Any]] = {}, ) -> Callable[[], str]: gen_kwargs, _ = HuggingfaceEngine._process_args( @@ -338,12 +338,12 @@ class HuggingfaceEngine(BaseEngine): @override async def chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> list["Response"]: if not self.can_generate: @@ -371,12 +371,12 @@ class HuggingfaceEngine(BaseEngine): @override async def stream_chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: if not self.can_generate: diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index ab478728..4d37e81f 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -13,7 +13,7 @@ # limitations under the License. 
import uuid -from collections.abc import AsyncGenerator, AsyncIterator, Sequence +from collections.abc import AsyncGenerator, AsyncIterator from typing import TYPE_CHECKING, Any, Optional, Union from typing_extensions import override @@ -102,12 +102,12 @@ class VllmEngine(BaseEngine): async def _generate( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncIterator["RequestOutput"]: request_id = f"chatcmpl-{uuid.uuid4().hex}" @@ -202,12 +202,12 @@ class VllmEngine(BaseEngine): @override async def chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> list["Response"]: final_output = None @@ -231,12 +231,12 @@ class VllmEngine(BaseEngine): @override async def stream_chat( self, - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, - images: Optional[Sequence["ImageInput"]] = None, - videos: Optional[Sequence["VideoInput"]] = None, - audios: Optional[Sequence["AudioInput"]] = None, + images: Optional[list["ImageInput"]] = None, + videos: Optional[list["VideoInput"]] = None, + audios: Optional[list["AudioInput"]] = None, **input_kwargs, ) -> AsyncGenerator[str, None]: generated_text = "" diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index be9d9eb0..4de6bc2c 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -1,4 +1,4 @@ -# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team. +# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team. # # This code is inspired by the OpenAccess AI Collective's axolotl library. # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Optional @@ -92,7 +91,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if self.template is None: raise ValueError("Template is required for MultiModalDataCollator.") - def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: batch_images, batch_videos, batch_audios = [], [], [] batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], [] for feature in features: @@ -205,7 +204,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager" compute_dtype: "torch.dtype" = torch.float32 - def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: features = super().__call__(features) if self.block_diag_attn and self.attn_implementation != "flash_attention_2": features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) @@ -221,7 +220,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): r"""Data collator for pairwise data.""" - def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: r"""Pad batched data to the longest sequence in the batch. We generate 2 * n examples where the first n examples represent chosen examples and @@ -247,7 +246,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq): r"""Data collator for KTO data.""" - def __call__(self, features: Sequence[dict[str, Any]]) -> dict[str, "torch.Tensor"]: + def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]: target_features = [] kl_features = [] kto_tags = [] diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py index 8449d7c5..25f39545 100644 --- a/src/llamafactory/data/converter.py +++ b/src/llamafactory/data/converter.py @@ -14,7 +14,6 @@ import os from abc import abstractmethod -from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional, Union @@ -37,7 +36,7 @@ class DatasetConverter: dataset_attr: "DatasetAttr" data_args: "DataArguments" - def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[list[Any]]: + def _find_medias(self, medias: Union[Any, list[Any]]) -> Optional[list[Any]]: r"""Optionally concatenate media path to media dir when loading from local disk.""" if not isinstance(medias, list): medias = [medias] if medias is not None else [] diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index 13a5b4cb..d3184fb6 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import Sequence from enum import Enum, unique from typing import TYPE_CHECKING, Optional, TypedDict, Union @@ -30,7 +29,7 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -SLOTS = Sequence[Union[str, set[str], dict[str, str]]] +SLOTS = list[Union[str, set[str], dict[str, str]]] @unique diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 5f243329..e495cc31 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -13,7 +13,6 @@ # limitations under the License. import os -from collections.abc import Sequence from typing import TYPE_CHECKING, Literal, Optional, Union import numpy as np @@ -157,7 +156,7 @@ def _load_single_dataset( def _get_merged_dataset( - dataset_names: Optional[Sequence[str]], + dataset_names: Optional[list[str]], model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index ffcb1408..b7e6852d 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -18,7 +18,6 @@ import inspect import math import re -from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass from io import BytesIO @@ -83,9 +82,7 @@ if TYPE_CHECKING: pass -def _get_paligemma_token_type_ids( - imglens: Sequence[int], seqlens: Sequence[int], processor: "MMProcessor" -) -> list[list[int]]: +def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]: r"""Get paligemma token type ids for computing loss. It is slightly different with the original token type ids where the prompt part is 0. @@ -120,7 +117,7 @@ def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcess return batch_token_type_ids -def _make_batched_images(images: Sequence["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]: +def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]: r"""Make nested list of images.""" batch_images = [] for imglen in imglens: @@ -140,9 +137,9 @@ class MMPluginMixin: def _validate_input( self, processor: Optional["MMProcessor"], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], ) -> None: r"""Validate if this model accepts the input modalities.""" image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) @@ -202,7 +199,7 @@ class MMPluginMixin: sample_frames = min(total_frames, video_maxlen, sample_frames) return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32) - def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> list["ImageObject"]: + def _regularize_images(self, images: list["ImageInput"], **kwargs) -> list["ImageObject"]: r"""Regularize images to avoid error. Including reading and pre-processing.""" results = [] for image in images: @@ -223,7 +220,7 @@ class MMPluginMixin: return results - def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> list[list["ImageObject"]]: + def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> list[list["ImageObject"]]: r"""Regularizes videos to avoid error. 
Including reading, resizing and converting.""" results = [] for video in videos: @@ -241,7 +238,7 @@ class MMPluginMixin: return results - def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]: + def _regularize_audios(self, audios: list["AudioInput"], sampling_rate: float, **kwargs) -> list["NDArray"]: r"""Regularizes audios to avoid error. Including reading and resampling.""" results = [] for audio in audios: @@ -257,9 +254,9 @@ class MMPluginMixin: def _get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: "MMProcessor", imglens: Optional[list[int]] = None, ) -> dict[str, "torch.Tensor"]: @@ -335,10 +332,10 @@ class MMPluginMixin: class BasePlugin(MMPluginMixin): def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: r"""Pre-process input messages before tokenization for VLMs.""" @@ -349,9 +346,9 @@ class BasePlugin(MMPluginMixin): self, input_ids: list[int], labels: Optional[list[int]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["MMProcessor"], ) -> tuple[list[int], Optional[list[int]]]: @@ -361,13 +358,13 @@ class BasePlugin(MMPluginMixin): def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: r"""Build batched multimodal inputs for VLMs. 
@@ -392,10 +389,10 @@ class Gemma3Plugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -420,13 +417,13 @@ class Gemma3Plugin(BasePlugin): @override def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) @@ -441,10 +438,10 @@ class LlavaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -481,10 +478,10 @@ class LlavaNextPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -523,10 +520,10 @@ class LlavaNextVideoPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -586,10 +583,10 @@ class MiniCPMVPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -686,9 +683,9 @@ class MiniCPMVPlugin(BasePlugin): @override def _get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: "MMProcessor", **kwargs, ) -> dict[str, "torch.Tensor"]: @@ -757,13 +754,13 @@ class MiniCPMVPlugin(BasePlugin): @override def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: 
Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) @@ -828,10 +825,10 @@ class MllamaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -850,13 +847,13 @@ class MllamaPlugin(BasePlugin): @override def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) @@ -885,10 +882,10 @@ class PaliGemmaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -912,9 +909,9 @@ class PaliGemmaPlugin(BasePlugin): self, input_ids: list[int], labels: Optional[list[int]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["MMProcessor"], ) -> tuple[list[int], Optional[list[int]]]: @@ -931,13 +928,13 @@ class PaliGemmaPlugin(BasePlugin): @override def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) @@ -952,10 +949,10 @@ class PixtralPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, 
images, videos, audios) @@ -995,13 +992,13 @@ class PixtralPlugin(BasePlugin): @override def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) @@ -1015,10 +1012,10 @@ class Qwen2AudioPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -1056,13 +1053,13 @@ class Qwen2AudioPlugin(BasePlugin): @override def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: self._validate_input(processor, images, videos, audios) @@ -1090,7 +1087,7 @@ class Qwen2VLPlugin(BasePlugin): @override def _regularize_videos( - self, videos: Sequence["VideoInput"], **kwargs + self, videos: list["VideoInput"], **kwargs ) -> tuple[list[list["ImageObject"]], list[float]]: results, fps_per_video = [], [] for video in videos: @@ -1118,9 +1115,9 @@ class Qwen2VLPlugin(BasePlugin): @override def _get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: "MMProcessor", ) -> dict[str, "torch.Tensor"]: image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) @@ -1149,10 +1146,10 @@ class Qwen2VLPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) @@ -1204,13 +1201,13 @@ class Qwen2VLPlugin(BasePlugin): @override def get_mm_inputs( self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - audlens: Sequence[int], - batch_ids: Sequence[list[int]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], + imglens: list[int], + vidlens: list[int], + audlens: list[int], + batch_ids: list[list[int]], processor: Optional["MMProcessor"], ) -> dict[str, Union[list[int], "torch.Tensor"]]: 
self._validate_input(processor, images, videos, audios) @@ -1229,10 +1226,10 @@ class VideoLlavaPlugin(BasePlugin): @override def process_messages( self, - messages: Sequence[dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + messages: list[dict[str, str]], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], processor: Optional["MMProcessor"], ) -> list[dict[str, str]]: self._validate_input(processor, images, videos, audios) diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 4e1c7aff..ccc1bdcd 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -14,7 +14,6 @@ import json import os -from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Literal, Optional @@ -91,7 +90,7 @@ class DatasetAttr: self.set_attr(tag, attr["tags"]) -def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> list["DatasetAttr"]: +def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]: r"""Get the attributes of the datasets.""" if dataset_names is None: dataset_names = [] diff --git a/src/llamafactory/data/processor/__init__.py b/src/llamafactory/data/processor/__init__.py index a827d005..357ab789 100644 --- a/src/llamafactory/data/processor/__init__.py +++ b/src/llamafactory/data/processor/__init__.py @@ -1,3 +1,17 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from .feedback import FeedbackDatasetProcessor from .pairwise import PairwiseDatasetProcessor from .pretrain import PretrainDatasetProcessor diff --git a/src/llamafactory/data/processor/feedback.py b/src/llamafactory/data/processor/feedback.py index 89233e10..ed9359ae 100644 --- a/src/llamafactory/data/processor/feedback.py +++ b/src/llamafactory/data/processor/feedback.py @@ -13,7 +13,6 @@ # limitations under the License. 
from collections import defaultdict -from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Optional from ...extras import logging @@ -31,14 +30,14 @@ logger = logging.get_logger(__name__) class FeedbackDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[dict[str, str]], - response: Sequence[dict[str, str]], - kl_response: Sequence[dict[str, str]], + prompt: list[dict[str, str]], + response: list[dict[str, str]], + kl_response: list[dict[str, str]], system: Optional[str], tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], ) -> tuple[list[int], list[int], list[int], list[int], bool]: if response[0]["content"]: # desired example kto_tag = True diff --git a/src/llamafactory/data/processor/pairwise.py b/src/llamafactory/data/processor/pairwise.py index e0a81f0b..94101deb 100644 --- a/src/llamafactory/data/processor/pairwise.py +++ b/src/llamafactory/data/processor/pairwise.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import defaultdict -from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Optional from ...extras import logging @@ -31,13 +30,13 @@ logger = logging.get_logger(__name__) class PairwiseDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[dict[str, str]], - response: Sequence[dict[str, str]], + prompt: list[dict[str, str]], + response: list[dict[str, str]], system: Optional[str], tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], ) -> tuple[list[int], list[int], list[int], list[int]]: chosen_messages = self.template.mm_plugin.process_messages( prompt + [response[0]], images, videos, audios, self.processor diff --git a/src/llamafactory/data/processor/pretrain.py b/src/llamafactory/data/processor/pretrain.py index 385b3914..3fa6b1ca 100644 --- a/src/llamafactory/data/processor/pretrain.py +++ b/src/llamafactory/data/processor/pretrain.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py diff --git a/src/llamafactory/data/processor/processor_utils.py b/src/llamafactory/data/processor/processor_utils.py index 528ff52b..db44b19c 100644 --- a/src/llamafactory/data/processor/processor_utils.py +++ b/src/llamafactory/data/processor/processor_utils.py @@ -14,7 +14,6 @@ import bisect from abc import ABC, abstractmethod -from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional @@ -46,7 +45,7 @@ class DatasetProcessor(ABC): ... 
-def search_for_fit(numbers: Sequence[int], capacity: int) -> int: +def search_for_fit(numbers: list[int], capacity: int) -> int: r"""Find the index of largest number that fits into the knapsack with the given capacity.""" index = bisect.bisect(numbers, capacity) return -1 if index == 0 else (index - 1) diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py index 1e62e9a3..d2ef508b 100644 --- a/src/llamafactory/data/processor/supervised.py +++ b/src/llamafactory/data/processor/supervised.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import defaultdict -from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional @@ -33,13 +32,13 @@ logger = logging.get_logger(__name__) class SupervisedDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[dict[str, str]], - response: Sequence[dict[str, str]], + prompt: list[dict[str, str]], + response: list[dict[str, str]], system: Optional[str], tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], ) -> tuple[list[int], list[int]]: messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor) input_ids, labels = self.template.mm_plugin.process_token_ids( diff --git a/src/llamafactory/data/processor/unsupervised.py b/src/llamafactory/data/processor/unsupervised.py index 2ce628d9..256174b6 100644 --- a/src/llamafactory/data/processor/unsupervised.py +++ b/src/llamafactory/data/processor/unsupervised.py @@ -13,7 +13,6 @@ # limitations under the License. from collections import defaultdict -from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Optional from ...extras import logging @@ -31,13 +30,13 @@ logger = logging.get_logger(__name__) class UnsupervisedDatasetProcessor(DatasetProcessor): def _encode_data_example( self, - prompt: Sequence[dict[str, str]], - response: Sequence[dict[str, str]], + prompt: list[dict[str, str]], + response: list[dict[str, str]], system: Optional[str], tools: Optional[str], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - audios: Sequence["AudioInput"], + images: list["ImageInput"], + videos: list["VideoInput"], + audios: list["AudioInput"], ) -> tuple[list[int], list[int]]: if len(response) == 1: messages = prompt + response diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 9454d40e..cc4a6dcb 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Optional, Union @@ -57,7 +56,7 @@ class Template: def encode_oneturn( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, ) -> tuple[list[int], list[int]]: @@ -73,7 +72,7 @@ class Template: def encode_multiturn( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str] = None, tools: Optional[str] = None, ) -> list[tuple[list[int], list[int]]]: @@ -115,7 +114,7 @@ class Template: def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: Optional[str], tools: Optional[str], ) -> list[list[int]]: @@ -316,7 +315,7 @@ class Llama2Template(Template): def _encode( self, tokenizer: "PreTrainedTokenizer", - messages: Sequence[dict[str, str]], + messages: list[dict[str, str]], system: str, tools: str, ) -> list[list[int]]: @@ -391,7 +390,7 @@ def register_template( format_tools: Optional["Formatter"] = None, format_prefix: Optional["Formatter"] = None, default_system: str = "", - stop_words: Optional[Sequence[str]] = None, + stop_words: Optional[list[str]] = None, thought_words: Optional[tuple[str, str]] = None, efficient_eos: bool = False, replace_eos: bool = False, diff --git a/src/llamafactory/eval/template.py b/src/llamafactory/eval/template.py index 83f70171..57424697 100644 --- a/src/llamafactory/eval/template.py +++ b/src/llamafactory/eval/template.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence from dataclasses import dataclass from ..data import Role @@ -35,7 +34,7 @@ class EvalTemplate: return "".join([example["question"]] + candidates + [self.answer]), example["answer"] def format_example( - self, target_data: dict[str, str], support_set: Sequence[dict[str, str]], subject_name: str + self, target_data: dict[str, str], support_set: list[dict[str, str]], subject_name: str ) -> list[dict[str, str]]: r"""Convert dataset examples to messages.""" messages = [] diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py index ef099397..659c1cf7 100644 --- a/src/llamafactory/extras/env.py +++ b/src/llamafactory/extras/env.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 8fc030a8..a6e0a4e3 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -1,4 +1,4 @@ -# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 Optuna, HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 076a4178..235b3043 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's PEFT library. # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py @@ -17,7 +17,6 @@ import gc import os -from collections.abc import Sequence from typing import TYPE_CHECKING, Any, Literal, Union import torch @@ -98,7 +97,7 @@ def check_dependencies() -> None: logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.") -def calculate_tps(dataset: Sequence[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float: +def calculate_tps(dataset: list[dict[str, Any]], metrics: dict[str, float], stage: Literal["sft", "rm"]) -> float: r"""Calculate effective tokens per second.""" effective_token_num = 0 for data in dataset: diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 62a3615a..b474633e 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 80b49248..3a66b2c0 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 2c7c00f8..3b6a1eea 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 464b8472..9ec819c7 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. and the LlamaFactory team. # # This code is inspired by the HuggingFace's transformers library. # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py index 38a93650..dfc41caf 100644 --- a/src/llamafactory/hparams/training_args.py +++ b/src/llamafactory/hparams/training_args.py @@ -1,3 +1,17 @@ +# Copyright 2025 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json from dataclasses import dataclass, field from typing import Literal, Optional, Union diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index eb498310..051f6b04 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc., Daniel Han-Chen & the Unsloth team and the LlamaFactory team. # # This code is inspired by the HuggingFace's Transformers and PEFT library, # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py index 12ea91e5..3413945e 100644 --- a/src/llamafactory/model/model_utils/longlora.py +++ b/src/llamafactory/model/model_utils/longlora.py @@ -1,4 +1,4 @@ -# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team. +# Copyright 2025 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team. # # This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries. # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py index 9e225ad5..bc4f2906 100644 --- a/src/llamafactory/model/model_utils/moe.py +++ b/src/llamafactory/model/model_utils/moe.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Sequence from typing import TYPE_CHECKING import torch @@ -27,7 +26,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: +def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: list["torch.nn.Module"]) -> None: check_version("deepspeed>=0.13.0") from deepspeed.utils import set_z3_leaf_modules # type: ignore diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 475d7bc3..824720d2 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -1,4 +1,4 @@ -# Copyright 2024 Musab Gultekin and the LlamaFactory team. +# Copyright 2025 Musab Gultekin and the LlamaFactory team. # # This code is based on the Musab Gultekin's functionary library. # https://github.com/MeetKai/functionary/blob/main/functionary/train/packing/monkey_patch_packing.py diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index 860e2c2a..e33fed86 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -1,4 +1,4 @@ -# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# Copyright 2025 HuggingFace Inc. 
 #
 # This code is inspired by the HuggingFace's Transformers and Optimum library.
 # https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py
diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index ccb9daf1..30d0fdd7 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -1,4 +1,4 @@
-# Copyright 2024 LMSYS and the LlamaFactory team.
+# Copyright 2025 LMSYS and the LlamaFactory team.
 # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
 #
 # This code is inspired by the LMSYS's FastChat library.
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 01c20988..76162802 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional

@@ -180,7 +179,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni


 def patch_target_modules(
-    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
+    model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: list[str]
 ) -> list[str]:
     r"""Freeze vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index e8499416..98c22022 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/dpo_trainer.py
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 7a9ff517..422a702e 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index d4c07092..0409c305 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index 74668720..45f82671 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index 00258acd..1684fb17 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 4e64a256..fa6629a2 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py
index ecbbe00d..3c04595a 100644
--- a/src/llamafactory/train/pt/workflow.py
+++ b/src/llamafactory/train/pt/workflow.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 508f7516..8c14b0ab 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index 14607cca..cb693187 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
diff --git a/src/llamafactory/train/sft/metric.py b/src/llamafactory/train/sft/metric.py
index 323d6241..f4f73ee4 100644
--- a/src/llamafactory/train/sft/metric.py
+++ b/src/llamafactory/train/sft/metric.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc., THUDM, and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
@@ -37,11 +37,11 @@ if is_jieba_available():


 if is_nltk_available():
-    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
+    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu  # type: ignore


 if is_rouge_available():
-    from rouge_chinese import Rouge
+    from rouge_chinese import Rouge  # type: ignore


 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index ea22c6bb..f90edeab 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 1474a74a..37006707 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
diff --git a/src/llamafactory/train/test_utils.py b/src/llamafactory/train/test_utils.py
index ceffe1e0..6e4c4ffc 100644
--- a/src/llamafactory/train/test_utils.py
+++ b/src/llamafactory/train/test_utils.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Optional, Union

 import torch
@@ -33,7 +32,7 @@ if TYPE_CHECKING:
     from ..data.data_utils import DatasetModule


-def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: Sequence[str] = []) -> None:
+def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_keys: list[str] = []) -> None:
     state_dict_a = model_a.state_dict()
     state_dict_b = model_b.state_dict()
     assert set(state_dict_a.keys()) == set(state_dict_b.keys())
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index 5f6aeb8f..198936eb 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore
 # and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus
diff --git a/tests/check_license.py b/tests/check_license.py
new file mode 100644
index 00000000..853d2399
--- /dev/null
+++ b/tests/check_license.py
@@ -0,0 +1,38 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from pathlib import Path
+
+
+KEYWORDS = ("Copyright", "2025", "LlamaFactory")
+
+
+def main():
+    path_list = []
+    for check_dir in sys.argv[1:]:
+        path_list.extend(Path(check_dir).glob("**/*.py"))
+
+    for path in path_list:
+        with open(path.absolute(), encoding="utf-8") as f:
+            file_content = f.read().strip().split("\n")
+            if not file_content[0]:
+                continue
+
+            print(f"Check license: {path}")
+            assert all(keyword in file_content[0] for keyword in KEYWORDS), f"File {path} does not contain license."
+
+
+if __name__ == "__main__":
+    main()
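Note: `tests/check_license.py` inspects only the first line of each non-empty `*.py` file, so the check reduces to a substring match against `KEYWORDS`. A self-contained illustration of what passes and fails; the two first-line strings below are made up for the example:

```python
KEYWORDS = ("Copyright", "2025", "LlamaFactory")

# A header like the ones added in this patch satisfies the check:
line_ok = "# Copyright 2025 HuggingFace Inc. and the LlamaFactory team."
assert all(keyword in line_ok for keyword in KEYWORDS)

# A stale year (or a shebang as the first line) would trip the assert:
line_bad = "# Copyright 2024 HuggingFace Inc. and the LlamaFactory team."
assert not all(keyword in line_bad for keyword in KEYWORDS)
```

One caveat worth knowing: `Path(check_dir).glob("**/*.py")` yields nothing when `check_dir` is a plain file rather than a directory, so any file argument is skipped silently rather than checked.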
diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py
index c47842bc..e064d195 100644
--- a/tests/data/test_mm_plugin.py
+++ b/tests/data/test_mm_plugin.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any

 import pytest
@@ -97,7 +96,7 @@ def _check_plugin(
     plugin: "BasePlugin",
     tokenizer: "PreTrainedTokenizer",
     processor: "ProcessorMixin",
-    expected_mm_messages: Sequence[dict[str, str]] = MM_MESSAGES,
+    expected_mm_messages: list[dict[str, str]] = MM_MESSAGES,
     expected_input_ids: list[int] = INPUT_IDS,
     expected_labels: list[int] = LABELS,
     expected_mm_inputs: dict[str, Any] = {},
diff --git a/tests/data/test_template.py b/tests/data/test_template.py
index b3f2052e..0f8bc6cc 100644
--- a/tests/data/test_template.py
+++ b/tests/data/test_template.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-from collections.abc import Sequence
 from typing import TYPE_CHECKING

 import pytest
@@ -41,7 +40,7 @@ MESSAGES = [


 def _check_tokenization(
-    tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str]
+    tokenizer: "PreTrainedTokenizer", batch_input_ids: list[list[int]], batch_text: list[str]
 ) -> None:
     r"""Check token ids and texts."""
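The patch excerpt ends inside this hunk, so `_check_tokenization`'s body is not shown. For intuition, a round-trip check of the shape the signature and docstring suggest might look like the following sketch; it assumes only what is visible above, and `ToyTokenizer` is a stand-in class, not part of the repository:

```python
def check_tokenization(tokenizer, batch_input_ids: list[list[int]], batch_text: list[str]) -> None:
    # Presumed shape: each ids/text pair should round-trip through the tokenizer.
    for input_ids, text in zip(batch_input_ids, batch_text):
        assert tokenizer.encode(text, add_special_tokens=False) == input_ids
        assert tokenizer.decode(input_ids) == text


class ToyTokenizer:
    """Stand-in tokenizer so the sketch runs without downloading a model."""

    _vocab = {"hello": 1, "world": 2}
    _ids = {i: tok for tok, i in _vocab.items()}

    def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
        return [self._vocab[token] for token in text.split()]

    def decode(self, input_ids: list[int]) -> str:
        return " ".join(self._ids[i] for i in input_ids)


check_tokenization(ToyTokenizer(), [[1, 2]], ["hello world"])  # passes silently
```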