From eceec8ab694da16dc31c60c8a2b75af73c13782f Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Sat, 27 Dec 2025 02:50:44 +0800 Subject: [PATCH] [deps] goodbye python 3.9 (#9677) Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> Co-authored-by: hiyouga --- .github/workflows/tests.yml | 12 ++-- pyproject.toml | 32 +++++----- scripts/megatron_merge.py | 7 +-- scripts/stat_utils/cal_ppl.py | 4 +- scripts/vllm_infer.py | 7 +-- src/llamafactory/api/app.py | 4 +- src/llamafactory/api/protocol.py | 41 +++++++------ src/llamafactory/chat/hf_engine.py | 4 +- src/llamafactory/data/converter.py | 4 +- src/llamafactory/data/formatter.py | 7 +-- src/llamafactory/data/loader.py | 8 +-- src/llamafactory/data/mm_plugin.py | 30 +++++----- src/llamafactory/data/parser.py | 52 ++++++++-------- src/llamafactory/extras/constants.py | 3 +- src/llamafactory/extras/logging.py | 2 +- src/llamafactory/hparams/data_args.py | 28 ++++----- src/llamafactory/hparams/evaluation_args.py | 4 +- src/llamafactory/hparams/finetuning_args.py | 52 ++++++++-------- src/llamafactory/hparams/model_args.py | 59 +++++++++---------- src/llamafactory/hparams/parser.py | 22 +++---- src/llamafactory/hparams/training_args.py | 10 ++-- .../model/model_utils/checkpointing.py | 3 +- src/llamafactory/train/rm/metric.py | 6 +- src/llamafactory/train/trainer_utils.py | 4 +- src/llamafactory/v1/accelerator/helper.py | 3 +- src/llamafactory/v1/accelerator/interface.py | 20 +++---- src/llamafactory/v1/config/arg_parser.py | 4 +- src/llamafactory/v1/config/arg_utils.py | 5 +- src/llamafactory/v1/config/data_args.py | 3 +- src/llamafactory/v1/config/model_args.py | 7 +-- src/llamafactory/v1/config/training_args.py | 3 +- src/llamafactory/v1/core/data_engine.py | 4 +- .../v1/plugins/data_plugins/converter.py | 4 +- .../v1/plugins/data_plugins/loader.py | 8 +-- .../v1/plugins/data_plugins/template.py | 5 +- .../plugins/model_plugins/kernels/registry.py | 15 ++--- .../v1/plugins/model_plugins/peft.py | 4 +- src/llamafactory/v1/utils/dtype.py | 13 ++-- src/llamafactory/v1/utils/logging.py | 2 +- src/llamafactory/v1/utils/plugin.py | 4 +- src/llamafactory/v1/utils/types.py | 4 +- src/llamafactory/webui/chatter.py | 10 ++-- src/llamafactory/webui/common.py | 8 +-- src/llamafactory/webui/components/export.py | 6 +- src/llamafactory/webui/control.py | 4 +- src/llamafactory/webui/runner.py | 4 +- tests/conftest.py | 3 +- tests_v1/conftest.py | 3 +- 48 files changed, 267 insertions(+), 284 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 036ff744e..3def9eb89 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,10 +25,9 @@ jobs: fail-fast: false matrix: python: - - "3.9" - - "3.10" - "3.11" - "3.12" + # - "3.13" # enable after trl is upgraded os: - "ubuntu-latest" - "windows-latest" @@ -36,18 +35,15 @@ jobs: transformers: - null include: # test backward compatibility - - python: "3.9" + - python: "3.11" os: "ubuntu-latest" transformers: "4.49.0" - - python: "3.9" + - python: "3.11" os: "ubuntu-latest" transformers: "4.51.0" - - python: "3.9" + - python: "3.11" os: "ubuntu-latest" transformers: "4.53.0" - exclude: # exclude python 3.9 on macos - - python: "3.9" - os: "macos-latest" runs-on: ${{ matrix.os }} diff --git a/pyproject.toml b/pyproject.toml index 732a812c6..9c8a9cb6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = 
["version"] description = "Unified Efficient Fine-Tuning of 100+ LLMs" readme = "README.md" license = "Apache-2.0" -requires-python = ">=3.9.0" +requires-python = ">=3.11.0" authors = [ { name = "hiyouga", email = "hiyouga@buaa.edu.cn" } ] @@ -30,10 +30,10 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Operating System :: OS Independent", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Artificial Intelligence" ] dependencies = [ @@ -98,24 +98,26 @@ path = "src/llamafactory/extras/env.py" pattern = "VERSION = \"(?P[^\"]+)\"" [tool.ruff] -target-version = "py39" +target-version = "py311" line-length = 119 indent-width = 4 [tool.ruff.lint] ignore = [ - "C408", # collection - "C901", # complex - "E501", # line too long - "E731", # lambda function - "E741", # ambiguous var name - "D100", # no doc public module - "D101", # no doc public class - "D102", # no doc public method - "D103", # no doc public function - "D104", # no doc public package - "D105", # no doc magic method - "D107", # no doc __init__ + "C408", # collection + "C901", # complex + "E501", # line too long + "E731", # lambda function + "E741", # ambiguous var name + "UP007", # no upgrade union + "UP045", # no upgrade optional + "D100", # no doc public module + "D101", # no doc public class + "D102", # no doc public method + "D103", # no doc public function + "D104", # no doc public package + "D105", # no doc magic method + "D107", # no doc __init__ ] extend-select = [ "C", # complexity diff --git a/scripts/megatron_merge.py b/scripts/megatron_merge.py index 47ad98f0e..4d9d932cd 100644 --- a/scripts/megatron_merge.py +++ b/scripts/megatron_merge.py @@ -16,7 +16,6 @@ # limitations under the License. import os -from typing import Optional import fire import torch @@ -34,7 +33,7 @@ def convert_mca_to_hf( output_path: str = "./output", bf16: bool = False, fp16: bool = False, - convert_model_max_length: Optional[int] = None, + convert_model_max_length: int | None = None, ): """Convert megatron checkpoint to HuggingFace format. @@ -67,11 +66,11 @@ def convert( output_path: str = "./output", bf16: bool = False, fp16: bool = False, - convert_model_max_length: Optional[int] = None, + convert_model_max_length: int | None = None, tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, expert_model_parallel_size: int = 1, - virtual_pipeline_model_parallel_size: Optional[int] = None, + virtual_pipeline_model_parallel_size: int | None = None, ): """Convert checkpoint between MCA and HuggingFace formats. diff --git a/scripts/stat_utils/cal_ppl.py b/scripts/stat_utils/cal_ppl.py index 8d47ffd87..56b3c8d11 100644 --- a/scripts/stat_utils/cal_ppl.py +++ b/scripts/stat_utils/cal_ppl.py @@ -14,7 +14,7 @@ import json from dataclasses import dataclass -from typing import Any, Literal, Optional +from typing import Any, Literal import fire import torch @@ -61,7 +61,7 @@ def calculate_ppl( dataset_dir: str = "data", template: str = "default", cutoff_len: int = 2048, - max_samples: Optional[int] = None, + max_samples: int | None = None, train_on_prompt: bool = False, ): r"""Calculate the ppl on the dataset of the pre-trained models. 
diff --git a/scripts/vllm_infer.py b/scripts/vllm_infer.py index 4d6f05862..6c157d6dd 100644 --- a/scripts/vllm_infer.py +++ b/scripts/vllm_infer.py @@ -14,7 +14,6 @@ import gc import json -from typing import Optional import av import fire @@ -49,7 +48,7 @@ def vllm_infer( dataset_dir: str = "data", template: str = "default", cutoff_len: int = 2048, - max_samples: Optional[int] = None, + max_samples: int | None = None, vllm_config: str = "{}", save_name: str = "generated_predictions.jsonl", temperature: float = 0.95, @@ -58,9 +57,9 @@ def vllm_infer( max_new_tokens: int = 1024, repetition_penalty: float = 1.0, skip_special_tokens: bool = True, - default_system: Optional[str] = None, + default_system: str | None = None, enable_thinking: bool = True, - seed: Optional[int] = None, + seed: int | None = None, pipeline_parallel_size: int = 1, image_max_pixels: int = 768 * 768, image_min_pixels: int = 32 * 32, diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py index e0621d80b..8ec0679cb 100644 --- a/src/llamafactory/api/app.py +++ b/src/llamafactory/api/app.py @@ -16,7 +16,7 @@ import asyncio import os from contextlib import asynccontextmanager from functools import partial -from typing import Annotated, Optional +from typing import Annotated from ..chat import ChatModel from ..extras.constants import EngineName @@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": api_key = os.getenv("API_KEY") security = HTTPBearer(auto_error=False) - async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): + async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]): if api_key and (auth is None or auth.credentials != api_key): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.") diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py index 889d938e0..675523f06 100644 --- a/src/llamafactory/api/protocol.py +++ b/src/llamafactory/api/protocol.py @@ -14,10 +14,9 @@ import time from enum import Enum, unique -from typing import Any, Optional, Union +from typing import Any, Literal from pydantic import BaseModel, Field -from typing_extensions import Literal @unique @@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel): class FunctionAvailable(BaseModel): type: Literal["function", "code_interpreter"] = "function" - function: Optional[FunctionDefinition] = None + function: FunctionDefinition | None = None class FunctionCall(BaseModel): @@ -77,35 +76,35 @@ class URL(BaseModel): class MultimodalInputItem(BaseModel): type: Literal["text", "image_url", "video_url", "audio_url"] - text: Optional[str] = None - image_url: Optional[URL] = None - video_url: Optional[URL] = None - audio_url: Optional[URL] = None + text: str | None = None + image_url: URL | None = None + video_url: URL | None = None + audio_url: URL | None = None class ChatMessage(BaseModel): role: Role - content: Optional[Union[str, list[MultimodalInputItem]]] = None - tool_calls: Optional[list[FunctionCall]] = None + content: str | list[MultimodalInputItem] | None = None + tool_calls: list[FunctionCall] | None = None class ChatCompletionMessage(BaseModel): - role: Optional[Role] = None - content: Optional[str] = None - tool_calls: Optional[list[FunctionCall]] = None + role: Role | None = None + content: str | None = None + tool_calls: list[FunctionCall] | None = None class ChatCompletionRequest(BaseModel): model: str messages: list[ChatMessage] - tools: 
Optional[list[FunctionAvailable]] = None - do_sample: Optional[bool] = None - temperature: Optional[float] = None - top_p: Optional[float] = None + tools: list[FunctionAvailable] | None = None + do_sample: bool | None = None + temperature: float | None = None + top_p: float | None = None n: int = 1 - presence_penalty: Optional[float] = None - max_tokens: Optional[int] = None - stop: Optional[Union[str, list[str]]] = None + presence_penalty: float | None = None + max_tokens: int | None = None + stop: str | list[str] | None = None stream: bool = False @@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel): class ChatCompletionStreamResponseChoice(BaseModel): index: int delta: ChatCompletionMessage - finish_reason: Optional[Finish] = None + finish_reason: Finish | None = None class ChatCompletionResponseUsage(BaseModel): @@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel): class ScoreEvaluationRequest(BaseModel): model: str messages: list[str] - max_length: Optional[int] = None + max_length: int | None = None class ScoreEvaluationResponse(BaseModel): diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index adaaaa872..1e670b92c 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -14,9 +14,9 @@ import asyncio import os -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from threading import Thread -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from transformers import GenerationConfig, TextIteratorStreamer diff --git a/src/llamafactory/data/converter.py b/src/llamafactory/data/converter.py index ac3735e64..7ec6f12be 100644 --- a/src/llamafactory/data/converter.py +++ b/src/llamafactory/data/converter.py @@ -15,7 +15,7 @@ import json import os from abc import abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Union from ..extras import logging from .data_utils import Role @@ -40,7 +40,7 @@ class DatasetConverter: dataset_attr: "DatasetAttr" data_args: "DataArguments" - def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]: + def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None: r"""Optionally concatenate media path to media dir when loading from local disk.""" if medias is None: return None diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index d13bb8589..1c080f881 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -16,7 +16,6 @@ import json import re from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Optional, Union from typing_extensions import override @@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils @dataclass class Formatter(ABC): slots: SLOTS = field(default_factory=list) - tool_format: Optional[str] = None + tool_format: str | None = None @abstractmethod def apply(self, **kwargs) -> SLOTS: r"""Forms a list of slots according to the inputs to encode.""" ... - def extract(self, content: str) -> Union[str, list["FunctionCall"]]: + def extract(self, content: str) -> str | list["FunctionCall"]: r"""Extract a list of tuples from the response message if using tools. Each tuple consists of function name and function arguments. 
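Note that the hunks above keep typing.Union and typing.Optional wherever a bare string forward reference is a union member (for example Union["MediaType", list["MediaType"], None] and the Optional["ProcessorMixin"] parameters left untouched), which also appears to be why the ruff rules UP007/UP045 stay ignored in pyproject.toml: a quoted name does not support the | operator when the annotation expression is evaluated. A small sketch of the distinction, with placeholder class names:

    from typing import Optional, Union

    # Fine: typing forms accept string forward references as members.
    def attach(processor: Optional["ProcessorMixin"] = None) -> Union["MediaType", None]: ...

    # TypeError at definition time (unless `from __future__ import annotations` is active):
    # unsupported operand type(s) for |: 'str' and 'NoneType'
    # def attach(processor: "ProcessorMixin" | None = None): ...

    # Also fine: quote the whole union instead of a single member.
    def attach_quoted(processor: "ProcessorMixin | None" = None) -> "MediaType | None": ...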
@@ -156,5 +155,5 @@ class ToolFormatter(Formatter): raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string @override - def extract(self, content: str) -> Union[str, list["FunctionCall"]]: + def extract(self, content: str) -> str | list["FunctionCall"]: return self.tool_utils.tool_extractor(content) diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index ad7667617..d3d44e6f3 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -162,13 +162,13 @@ def _load_single_dataset( def _get_merged_dataset( - dataset_names: Optional[list[str]], + dataset_names: list[str] | None, model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], return_dict: bool = False, -) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]: +) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None: r"""Return the merged datasets in the standard format.""" if dataset_names is None: return None @@ -227,7 +227,7 @@ def _get_dataset_processor( def _get_preprocessed_dataset( - dataset: Optional[Union["Dataset", "IterableDataset"]], + dataset: Union["Dataset", "IterableDataset"] | None, data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo", "kto"], @@ -235,7 +235,7 @@ def _get_preprocessed_dataset( tokenizer: "PreTrainedTokenizer", processor: Optional["ProcessorMixin"] = None, is_eval: bool = False, -) -> Optional[Union["Dataset", "IterableDataset"]]: +) -> Union["Dataset", "IterableDataset"] | None: r"""Preprocesses the dataset, including format checking and tokenization.""" if dataset is None: return None diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 05acded2c..2ecf06a66 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -22,7 +22,7 @@ import re from copy import deepcopy from dataclasses import dataclass from io import BytesIO -from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union +from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union import numpy as np import torch @@ -32,7 +32,7 @@ from transformers.models.mllama.processing_mllama import ( convert_sparse_cross_attention_mask_to_dense, get_cross_attention_token_mask, ) -from typing_extensions import NotRequired, override +from typing_extensions import override from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than @@ -63,8 +63,8 @@ if TYPE_CHECKING: from transformers.video_processing_utils import BaseVideoProcessor class EncodedImage(TypedDict): - path: Optional[str] - bytes: Optional[bytes] + path: str | None + bytes: bytes | None ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject] VideoInput = Union[str, BinaryIO, list[list[ImageInput]]] @@ -144,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool: @dataclass class MMPluginMixin: - image_token: Optional[str] - video_token: Optional[str] - audio_token: Optional[str] + image_token: str | None + video_token: str | None + audio_token: str | None expand_mm_tokens: bool = True def _validate_input( @@ -328,7 +328,7 @@ class MMPluginMixin: videos: list["VideoInput"], audios: list["AudioInput"], processor: "MMProcessor", 
- imglens: Optional[list[int]] = None, + imglens: list[int] | None = None, ) -> dict[str, "torch.Tensor"]: r"""Process visual inputs. @@ -426,13 +426,13 @@ class BasePlugin(MMPluginMixin): def process_token_ids( self, input_ids: list[int], - labels: Optional[list[int]], + labels: list[int] | None, images: list["ImageInput"], videos: list["VideoInput"], audios: list["AudioInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["MMProcessor"], - ) -> tuple[list[int], Optional[list[int]]]: + ) -> tuple[list[int], list[int] | None]: r"""Pre-process token ids after tokenization for VLMs.""" self._validate_input(processor, images, videos, audios) return input_ids, labels @@ -1305,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin): def process_token_ids( self, input_ids: list[int], - labels: Optional[list[int]], + labels: list[int] | None, images: list["ImageInput"], videos: list["VideoInput"], audios: list["AudioInput"], tokenizer: "PreTrainedTokenizer", processor: Optional["MMProcessor"], - ) -> tuple[list[int], Optional[list[int]]]: + ) -> tuple[list[int], list[int] | None]: self._validate_input(processor, images, videos, audios) num_images = len(images) image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token @@ -2126,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None: def get_mm_plugin( name: str, - image_token: Optional[str] = None, - video_token: Optional[str] = None, - audio_token: Optional[str] = None, + image_token: str | None = None, + video_token: str | None = None, + audio_token: str | None = None, **kwargs, ) -> "BasePlugin": r"""Get plugin for multimodal inputs.""" diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py index 3a865fd83..5209da649 100644 --- a/src/llamafactory/data/parser.py +++ b/src/llamafactory/data/parser.py @@ -15,7 +15,7 @@ import json import os from dataclasses import dataclass -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from huggingface_hub import hf_hub_download @@ -33,40 +33,40 @@ class DatasetAttr: formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca" ranking: bool = False # extra configs - subset: Optional[str] = None + subset: str | None = None split: str = "train" - folder: Optional[str] = None - num_samples: Optional[int] = None + folder: str | None = None + num_samples: int | None = None # common columns - system: Optional[str] = None - tools: Optional[str] = None - images: Optional[str] = None - videos: Optional[str] = None - audios: Optional[str] = None + system: str | None = None + tools: str | None = None + images: str | None = None + videos: str | None = None + audios: str | None = None # dpo columns - chosen: Optional[str] = None - rejected: Optional[str] = None - kto_tag: Optional[str] = None + chosen: str | None = None + rejected: str | None = None + kto_tag: str | None = None # alpaca columns - prompt: Optional[str] = "instruction" - query: Optional[str] = "input" - response: Optional[str] = "output" - history: Optional[str] = None + prompt: str | None = "instruction" + query: str | None = "input" + response: str | None = "output" + history: str | None = None # sharegpt columns - messages: Optional[str] = "conversations" + messages: str | None = "conversations" # sharegpt tags - role_tag: Optional[str] = "from" - content_tag: Optional[str] = "value" - user_tag: Optional[str] = "human" - assistant_tag: Optional[str] = "gpt" - observation_tag: Optional[str] = "observation" - function_tag: 
Optional[str] = "function_call" - system_tag: Optional[str] = "system" + role_tag: str | None = "from" + content_tag: str | None = "value" + user_tag: str | None = "human" + assistant_tag: str | None = "gpt" + observation_tag: str | None = "observation" + function_tag: str | None = "function_call" + system_tag: str | None = "system" def __repr__(self) -> str: return self.dataset_name - def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None: + def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None: setattr(self, key, obj.get(key, default)) def join(self, attr: dict[str, Any]) -> None: @@ -90,7 +90,7 @@ class DatasetAttr: self.set_attr(tag, attr["tags"]) -def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]: +def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]: r"""Get the attributes of the datasets.""" if dataset_names is None: dataset_names = [] diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index ca45c7d49..eb053150c 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -15,7 +15,6 @@ import os from collections import OrderedDict, defaultdict from enum import Enum, unique -from typing import Optional from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME @@ -154,7 +153,7 @@ class RopeScaling(str, Enum): def register_model_group( models: dict[str, dict[DownloadSource, str]], - template: Optional[str] = None, + template: str | None = None, multimodal: bool = False, ) -> None: for name, path in models.items(): diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index f234a807f..6997200a3 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None: library_root_logger.propagate = False -def get_logger(name: Optional[str] = None) -> "_Logger": +def get_logger(name: str | None = None) -> "_Logger": r"""Return a logger with the specified name. It it not supposed to be accessed externally.""" if name is None: name = _get_library_name() diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index e6844733e..921019a02 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -16,22 +16,22 @@ # limitations under the License. from dataclasses import asdict, dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Literal @dataclass class DataArguments: r"""Arguments pertaining to what data we are going to input our model for training and evaluation.""" - template: Optional[str] = field( + template: str | None = field( default=None, metadata={"help": "Which template to use for constructing prompts in training and inference."}, ) - dataset: Optional[str] = field( + dataset: str | None = field( default=None, metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."}, ) - eval_dataset: Optional[str] = field( + eval_dataset: str | None = field( default=None, metadata={"help": "The name of dataset(s) to use for evaluation. 
Use commas to separate multiple datasets."}, ) @@ -39,7 +39,7 @@ class DataArguments: default="data", metadata={"help": "Path to the folder containing the datasets."}, ) - media_dir: Optional[str] = field( + media_dir: str | None = field( default=None, metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."}, ) @@ -67,7 +67,7 @@ class DataArguments: default="concat", metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, ) - interleave_probs: Optional[str] = field( + interleave_probs: str | None = field( default=None, metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}, ) @@ -79,15 +79,15 @@ class DataArguments: default=1000, metadata={"help": "The number of examples in one group in pre-processing."}, ) - preprocessing_num_workers: Optional[int] = field( + preprocessing_num_workers: int | None = field( default=None, metadata={"help": "The number of processes to use for the pre-processing."}, ) - max_samples: Optional[int] = field( + max_samples: int | None = field( default=None, metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, ) - eval_num_beams: Optional[int] = field( + eval_num_beams: int | None = field( default=None, metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}, ) @@ -103,7 +103,7 @@ class DataArguments: default=False, metadata={"help": "Whether or not to evaluate on each dataset separately."}, ) - packing: Optional[bool] = field( + packing: bool | None = field( default=None, metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."}, ) @@ -111,19 +111,19 @@ class DataArguments: default=False, metadata={"help": "Enable sequence packing without cross-attention."}, ) - tool_format: Optional[str] = field( + tool_format: str | None = field( default=None, metadata={"help": "Tool format to use for constructing function calling examples."}, ) - default_system: Optional[str] = field( + default_system: str | None = field( default=None, metadata={"help": "Override the default system message in the template."}, ) - enable_thinking: Optional[bool] = field( + enable_thinking: bool | None = field( default=True, metadata={"help": "Whether or not to enable thinking mode for reasoning models."}, ) - tokenized_path: Optional[str] = field( + tokenized_path: str | None = field( default=None, metadata={ "help": ( diff --git a/src/llamafactory/hparams/evaluation_args.py b/src/llamafactory/hparams/evaluation_args.py index d92e8b1ea..eddc618ba 100644 --- a/src/llamafactory/hparams/evaluation_args.py +++ b/src/llamafactory/hparams/evaluation_args.py @@ -14,7 +14,7 @@ import os from dataclasses import dataclass, field -from typing import Literal, Optional +from typing import Literal from datasets import DownloadMode @@ -46,7 +46,7 @@ class EvaluationArguments: default=5, metadata={"help": "Number of examplars for few-shot learning."}, ) - save_dir: Optional[str] = field( + save_dir: str | None = field( default=None, metadata={"help": "Path to save the evaluation results."}, ) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index ef690d7bb..7ab2ce3bc 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -13,7 +13,7 @@ # limitations under the License. 
from dataclasses import asdict, dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Literal @dataclass @@ -40,7 +40,7 @@ class FreezeArguments: ) }, ) - freeze_extra_modules: Optional[str] = field( + freeze_extra_modules: str | None = field( default=None, metadata={ "help": ( @@ -56,7 +56,7 @@ class FreezeArguments: class LoraArguments: r"""Arguments pertaining to the LoRA training.""" - additional_target: Optional[str] = field( + additional_target: str | None = field( default=None, metadata={ "help": ( @@ -66,7 +66,7 @@ class LoraArguments: ) }, ) - lora_alpha: Optional[int] = field( + lora_alpha: int | None = field( default=None, metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, ) @@ -88,7 +88,7 @@ class LoraArguments: ) }, ) - loraplus_lr_ratio: Optional[float] = field( + loraplus_lr_ratio: float | None = field( default=None, metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."}, ) @@ -126,7 +126,7 @@ class LoraArguments: class OFTArguments: r"""Arguments pertaining to the OFT training.""" - additional_target: Optional[str] = field( + additional_target: str | None = field( default=None, metadata={ "help": ( @@ -220,27 +220,27 @@ class RLHFArguments: default=False, metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, ) - ref_model: Optional[str] = field( + ref_model: str | None = field( default=None, metadata={"help": "Path to the reference model used for the PPO or DPO training."}, ) - ref_model_adapters: Optional[str] = field( + ref_model_adapters: str | None = field( default=None, metadata={"help": "Path to the adapters of the reference model."}, ) - ref_model_quantization_bit: Optional[int] = field( + ref_model_quantization_bit: int | None = field( default=None, metadata={"help": "The number of bits to quantize the reference model."}, ) - reward_model: Optional[str] = field( + reward_model: str | None = field( default=None, metadata={"help": "Path to the reward model used for the PPO training."}, ) - reward_model_adapters: Optional[str] = field( + reward_model_adapters: str | None = field( default=None, metadata={"help": "Path to the adapters of the reward model."}, ) - reward_model_quantization_bit: Optional[int] = field( + reward_model_quantization_bit: int | None = field( default=None, metadata={"help": "The number of bits to quantize the reward model."}, ) @@ -248,7 +248,7 @@ class RLHFArguments: default="lora", metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}, ) - ld_alpha: Optional[float] = field( + ld_alpha: float | None = field( default=None, metadata={ "help": ( @@ -361,15 +361,15 @@ class BAdamArgument: default="layer", metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."}, ) - badam_start_block: Optional[int] = field( + badam_start_block: int | None = field( default=None, metadata={"help": "The starting block index for layer-wise BAdam."}, ) - badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field( + badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field( default="ascending", metadata={"help": "the strategy of picking block to update for layer-wise BAdam."}, ) - badam_switch_interval: Optional[int] = field( + badam_switch_interval: int | None = field( default=50, metadata={ "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update." 
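The hparams dataclasses above follow one recurring shape: an argument is declared as `str | None` with a help string in the field metadata, then normalized in __post_init__ (the split_arg calls) into `list[str] | None`. A condensed, self-contained sketch of that shape; the helper below is a simplified stand-in, not the project's exact implementation:

    from dataclasses import dataclass, field

    def split_arg(arg: str | list[str] | None) -> list[str] | None:
        # Simplified: turn a comma-separated CLI string into a list; pass lists and None through.
        if isinstance(arg, str):
            return [item.strip() for item in arg.split(",")]
        return arg

    @dataclass
    class FreezeArgumentsSketch:
        freeze_extra_modules: str | None = field(
            default=None,
            metadata={"help": "Extra modules besides hidden layers to set as trainable."},
        )

        def __post_init__(self):
            # Re-typed after parsing, mirroring the `list[str] | None` assignments above.
            self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)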
@@ -406,15 +406,15 @@ class SwanLabArguments: default=False, metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."}, ) - swanlab_project: Optional[str] = field( + swanlab_project: str | None = field( default="llamafactory", metadata={"help": "The project name in SwanLab."}, ) - swanlab_workspace: Optional[str] = field( + swanlab_workspace: str | None = field( default=None, metadata={"help": "The workspace name in SwanLab."}, ) - swanlab_run_name: Optional[str] = field( + swanlab_run_name: str | None = field( default=None, metadata={"help": "The experiment name in SwanLab."}, ) @@ -422,19 +422,19 @@ class SwanLabArguments: default="cloud", metadata={"help": "The mode of SwanLab."}, ) - swanlab_api_key: Optional[str] = field( + swanlab_api_key: str | None = field( default=None, metadata={"help": "The API key for SwanLab."}, ) - swanlab_logdir: Optional[str] = field( + swanlab_logdir: str | None = field( default=None, metadata={"help": "The log directory for SwanLab."}, ) - swanlab_lark_webhook_url: Optional[str] = field( + swanlab_lark_webhook_url: str | None = field( default=None, metadata={"help": "The Lark(飞书) webhook URL for SwanLab."}, ) - swanlab_lark_secret: Optional[str] = field( + swanlab_lark_secret: str | None = field( default=None, metadata={"help": "The Lark(飞书) secret for SwanLab."}, ) @@ -510,7 +510,7 @@ class FinetuningArguments( default=False, metadata={"help": "Whether or not to disable the shuffling of the training set."}, ) - early_stopping_steps: Optional[int] = field( + early_stopping_steps: int | None = field( default=None, metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."}, ) @@ -530,11 +530,11 @@ class FinetuningArguments( return arg self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules) - self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules) + self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules) self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 self.lora_target: list[str] = split_arg(self.lora_target) self.oft_target: list[str] = split_arg(self.oft_target) - self.additional_target: Optional[list[str]] = split_arg(self.additional_target) + self.additional_target: list[str] | None = split_arg(self.additional_target) self.galore_target: list[str] = split_arg(self.galore_target) self.apollo_target: list[str] = split_arg(self.apollo_target) self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"] diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index 02c100ec8..0d0be63e4 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -17,12 +17,11 @@ import json from dataclasses import asdict, dataclass, field, fields -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Self import torch from omegaconf import OmegaConf from transformers.training_args import _convert_str_dict -from typing_extensions import Self from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling from ..extras.logging import get_logger @@ -35,13 +34,13 @@ logger = get_logger(__name__) class BaseModelArguments: r"""Arguments pertaining to the model.""" - model_name_or_path: Optional[str] = field( + model_name_or_path: str | None = field( default=None, metadata={ "help": "Path to the model weight or identifier from 
huggingface.co/models or modelscope.cn/models." }, ) - adapter_name_or_path: Optional[str] = field( + adapter_name_or_path: str | None = field( default=None, metadata={ "help": ( @@ -50,11 +49,11 @@ class BaseModelArguments: ) }, ) - adapter_folder: Optional[str] = field( + adapter_folder: str | None = field( default=None, metadata={"help": "The folder containing the adapter weights to load."}, ) - cache_dir: Optional[str] = field( + cache_dir: str | None = field( default=None, metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, ) @@ -70,17 +69,17 @@ class BaseModelArguments: default=False, metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, ) - add_tokens: Optional[str] = field( + add_tokens: str | None = field( default=None, metadata={ "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens." }, ) - add_special_tokens: Optional[str] = field( + add_special_tokens: str | None = field( default=None, metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."}, ) - new_special_tokens_config: Optional[str] = field( + new_special_tokens_config: str | None = field( default=None, metadata={ "help": ( @@ -110,7 +109,7 @@ class BaseModelArguments: default=True, metadata={"help": "Whether or not to use memory-efficient model loading."}, ) - rope_scaling: Optional[RopeScaling] = field( + rope_scaling: RopeScaling | None = field( default=None, metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, ) @@ -122,7 +121,7 @@ class BaseModelArguments: default=False, metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, ) - mixture_of_depths: Optional[Literal["convert", "load"]] = field( + mixture_of_depths: Literal["convert", "load"] | None = field( default=None, metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, ) @@ -138,7 +137,7 @@ class BaseModelArguments: default=False, metadata={"help": "Whether or not to enable liger kernel for faster training."}, ) - moe_aux_loss_coef: Optional[float] = field( + moe_aux_loss_coef: float | None = field( default=None, metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."}, ) @@ -182,15 +181,15 @@ class BaseModelArguments: default="auto", metadata={"help": "Data type for model weights and activations at inference."}, ) - hf_hub_token: Optional[str] = field( + hf_hub_token: str | None = field( default=None, metadata={"help": "Auth token to log in with Hugging Face Hub."}, ) - ms_hub_token: Optional[str] = field( + ms_hub_token: str | None = field( default=None, metadata={"help": "Auth token to log in with ModelScope Hub."}, ) - om_hub_token: Optional[str] = field( + om_hub_token: str | None = field( default=None, metadata={"help": "Auth token to log in with Modelers Hub."}, ) @@ -283,7 +282,7 @@ class QuantizationArguments: default=QuantizationMethod.BNB, metadata={"help": "Quantization method to use for on-the-fly quantization."}, ) - quantization_bit: Optional[int] = field( + quantization_bit: int | None = field( default=None, metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."}, ) @@ -295,7 +294,7 @@ class QuantizationArguments: default=True, metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."}, ) - quantization_device_map: Optional[Literal["auto"]] = 
field( + quantization_device_map: Literal["auto"] | None = field( default=None, metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."}, ) @@ -375,7 +374,7 @@ class ProcessorArguments: class ExportArguments: r"""Arguments pertaining to the model export.""" - export_dir: Optional[str] = field( + export_dir: str | None = field( default=None, metadata={"help": "Path to the directory to save the exported model."}, ) @@ -387,11 +386,11 @@ class ExportArguments: default="cpu", metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, ) - export_quantization_bit: Optional[int] = field( + export_quantization_bit: int | None = field( default=None, metadata={"help": "The number of bits to quantize the exported model."}, ) - export_quantization_dataset: Optional[str] = field( + export_quantization_dataset: str | None = field( default=None, metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, ) @@ -407,7 +406,7 @@ class ExportArguments: default=False, metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, ) - export_hub_model_id: Optional[str] = field( + export_hub_model_id: str | None = field( default=None, metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, ) @@ -437,7 +436,7 @@ class VllmArguments: default=32, metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, ) - vllm_config: Optional[Union[dict, str]] = field( + vllm_config: dict | str | None = field( default=None, metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."}, ) @@ -463,7 +462,7 @@ class SGLangArguments: default=-1, metadata={"help": "Tensor parallel size for the SGLang engine."}, ) - sglang_config: Optional[Union[dict, str]] = field( + sglang_config: dict | str | None = field( default=None, metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."}, ) @@ -487,21 +486,21 @@ class KTransformersArguments: default=False, metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."}, ) - kt_optimize_rule: Optional[str] = field( + kt_optimize_rule: str | None = field( default=None, metadata={ "help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/." }, ) - cpu_infer: Optional[int] = field( + cpu_infer: int | None = field( default=32, metadata={"help": "Number Of CPU Cores Used For Computation."}, ) - chunk_size: Optional[int] = field( + chunk_size: int | None = field( default=8192, metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."}, ) - mode: Optional[str] = field( + mode: str | None = field( default="normal", metadata={"help": "Normal Or Long_Context For Llama Models."}, ) @@ -539,17 +538,17 @@ class ModelArguments( The class on the most right will be displayed first. """ - compute_dtype: Optional[torch.dtype] = field( + compute_dtype: torch.dtype | None = field( default=None, init=False, metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."}, ) - device_map: Optional[Union[str, dict[str, Any]]] = field( + device_map: str | dict[str, Any] | None = field( default=None, init=False, metadata={"help": "Device map for model placement, derived from training stage. 
Do not specify it."}, ) - model_max_length: Optional[int] = field( + model_max_length: int | None = field( default=None, init=False, metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."}, diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index a3d9ddee2..5b262d68f 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -18,7 +18,7 @@ import os import sys from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Optional import torch import transformers @@ -65,7 +65,7 @@ else: _TRAIN_MCA_CLS = tuple() -def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]: +def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]: r"""Get arguments from the command line or a config file.""" if args is not None: return args @@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[ def _parse_args( - parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False + parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False ) -> tuple[Any]: args = read_args(args) if isinstance(args, dict): @@ -205,13 +205,13 @@ def _check_extra_dependencies( check_version("rouge_chinese", mandatory=True) -def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS: +def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS: parser = HfArgumentParser(_TRAIN_ARGS) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) -def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS: +def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS: parser = HfArgumentParser(_TRAIN_MCA_ARGS) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") model_args, data_args, training_args, finetuning_args, generating_args = _parse_args( @@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N finetuning_args.use_mca = True -def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS: +def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS: parser = HfArgumentParser(_INFER_ARGS) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) -def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS: +def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS: parser = HfArgumentParser(_EVAL_ARGS) allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS") return _parse_args(parser, args, allow_extra_keys=allow_extra_keys) -def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments: +def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments: parser = HfArgumentParser(RayArguments) (ray_args,) = _parse_args(parser, args, allow_extra_keys=True) return ray_args -def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS: +def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS: if is_env_enabled("USE_MCA"): model_args, data_args, training_args, finetuning_args, 
generating_args = _parse_train_mca_args(args) else: @@ -473,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _ return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS: +def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) # Setup logging @@ -508,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _ return model_args, data_args, finetuning_args, generating_args -def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS: +def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS: model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) # Setup logging diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py index 46b40a2dd..86ac0802f 100644 --- a/src/llamafactory/hparams/training_args.py +++ b/src/llamafactory/hparams/training_args.py @@ -14,7 +14,7 @@ import json from dataclasses import dataclass, field -from typing import Literal, Optional, Union +from typing import Literal from transformers import Seq2SeqTrainingArguments from transformers.training_args import _convert_str_dict @@ -40,7 +40,7 @@ else: class RayArguments: r"""Arguments pertaining to the Ray training.""" - ray_run_name: Optional[str] = field( + ray_run_name: str | None = field( default=None, metadata={"help": "The training results will be saved at `/ray_run_name`."}, ) @@ -48,7 +48,7 @@ class RayArguments: default="./saves", metadata={"help": "The storage path to save training results to"}, ) - ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field( + ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field( default=None, metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."}, ) @@ -56,7 +56,7 @@ class RayArguments: default=1, metadata={"help": "The number of workers for Ray training. Default is 1 worker."}, ) - resources_per_worker: Union[dict, str] = field( + resources_per_worker: dict | str = field( default_factory=lambda: {"GPU": 1}, metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."}, ) @@ -64,7 +64,7 @@ class RayArguments: default="PACK", metadata={"help": "The placement strategy for Ray training. Default is PACK."}, ) - ray_init_kwargs: Optional[Union[dict, str]] = field( + ray_init_kwargs: dict | str | None = field( default=None, metadata={"help": "The arguments to pass to ray.init for Ray training. 
Default is None."}, ) diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index 3e8341c1c..0ba7ec96a 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -20,9 +20,10 @@ import inspect import os +from collections.abc import Callable from functools import WRAPPER_ASSIGNMENTS, partial, wraps from types import MethodType -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch diff --git a/src/llamafactory/train/rm/metric.py b/src/llamafactory/train/rm/metric.py index a7c3c43f5..ae334cd9a 100644 --- a/src/llamafactory/train/rm/metric.py +++ b/src/llamafactory/train/rm/metric.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import numpy as np @@ -28,7 +28,7 @@ if TYPE_CHECKING: class ComputeAccuracy: r"""Compute reward accuracy and support `batch_eval_metrics`.""" - def _dump(self) -> Optional[dict[str, float]]: + def _dump(self) -> dict[str, float] | None: result = None if hasattr(self, "score_dict"): result = {k: float(np.mean(v)) for k, v in self.score_dict.items()} @@ -39,7 +39,7 @@ class ComputeAccuracy: def __post_init__(self): self._dump() - def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]: + def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> dict[str, float] | None: chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1]) if not chosen_scores.shape: self.score_dict["accuracy"].append(chosen_scores > rejected_scores) diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 66bff5f19..60adb2ecc 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -19,9 +19,9 @@ import json import os -from collections.abc import Mapping +from collections.abc import Callable, Mapping from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union import torch from transformers import Trainer diff --git a/src/llamafactory/v1/accelerator/helper.py b/src/llamafactory/v1/accelerator/helper.py index 76ed3ad46..04ded900c 100644 --- a/src/llamafactory/v1/accelerator/helper.py +++ b/src/llamafactory/v1/accelerator/helper.py @@ -25,10 +25,11 @@ Including: """ import os +from collections.abc import Callable from contextlib import contextmanager from enum import Enum, unique from functools import lru_cache, wraps -from typing import Callable, Optional +from typing import Optional import numpy as np import torch diff --git a/src/llamafactory/v1/accelerator/interface.py b/src/llamafactory/v1/accelerator/interface.py index 810776342..f8a5856d2 100644 --- a/src/llamafactory/v1/accelerator/interface.py +++ b/src/llamafactory/v1/accelerator/interface.py @@ -53,9 +53,9 @@ class DistributedStrategy: mp_replicate_size: int = 1 """Model parallel replicate size, default to 1.""" - mp_shard_size: Optional[int] = None + mp_shard_size: int | None = None """Model parallel shard size, default to world_size // mp_replicate_size.""" - dp_size: Optional[int] = None + dp_size: int | None = None """Data parallel size, default to world_size // cp_size.""" cp_size: int = 1 """Context parallel size, 
default to 1.""" @@ -115,7 +115,7 @@ class DistributedInterface: return cls._instance - def __init__(self, config: Optional[DistributedConfig] = None) -> None: + def __init__(self, config: DistributedConfig | None = None) -> None: if self._initialized: return @@ -165,7 +165,7 @@ class DistributedInterface: f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}" ) - def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]: + def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None: """Get device mesh for specified dimension.""" if dim is None: raise ValueError("dim must be specified.") @@ -176,14 +176,14 @@ class DistributedInterface: else: return self.model_device_mesh[dim.value] - def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]: + def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]: """Get process group for specified dimension.""" if self.model_device_mesh is None or dim is None: return None else: return self.get_device_mesh(dim).get_group() - def get_rank(self, dim: Optional[Dim] = None) -> int: + def get_rank(self, dim: Dim | None = None) -> int: """Get parallel rank for specified dimension.""" if self.model_device_mesh is None: return 0 @@ -192,7 +192,7 @@ class DistributedInterface: else: return self.get_device_mesh(dim).get_local_rank() - def get_world_size(self, dim: Optional[Dim] = None) -> int: + def get_world_size(self, dim: Dim | None = None) -> int: """Get parallel size for specified dimension.""" if self.model_device_mesh is None: return 1 @@ -209,7 +209,7 @@ class DistributedInterface: """Get parallel local world size.""" return self._local_world_size - def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor: + def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor: """Gather tensor across specified parallel group.""" if self.model_device_mesh is not None: return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim)) @@ -217,7 +217,7 @@ class DistributedInterface: return data def all_reduce( - self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP + self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP ) -> TensorLike: """Reduce tensor across specified parallel group.""" if self.model_device_mesh is not None: @@ -225,7 +225,7 @@ class DistributedInterface: else: return data - def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike: + def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike: """Broadcast tensor across specified parallel group.""" if self.model_device_mesh is not None: return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim)) diff --git a/src/llamafactory/v1/config/arg_parser.py b/src/llamafactory/v1/config/arg_parser.py index 05b4f69f3..adec3e4bb 100644 --- a/src/llamafactory/v1/config/arg_parser.py +++ b/src/llamafactory/v1/config/arg_parser.py @@ -15,7 +15,7 @@ import json import sys from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from omegaconf import OmegaConf from transformers import HfArgumentParser @@ -27,7 +27,7 @@ from .sample_args import SampleArguments from .training_args import TrainingArguments -InputArgument = Optional[Union[dict[str, Any], list[str]]] +InputArgument = dict[str, Any] | list[str] | None def validate_args( diff --git 
a/src/llamafactory/v1/config/arg_utils.py b/src/llamafactory/v1/config/arg_utils.py index 43c708b0a..5335cdbb5 100644 --- a/src/llamafactory/v1/config/arg_utils.py +++ b/src/llamafactory/v1/config/arg_utils.py @@ -18,7 +18,6 @@ import json from enum import Enum, unique -from typing import Optional, Union class PluginConfig(dict): @@ -33,7 +32,7 @@ class PluginConfig(dict): return self["name"] -PluginArgument = Optional[Union[PluginConfig, dict, str]] +PluginArgument = PluginConfig | dict | str | None @unique @@ -74,7 +73,7 @@ def _convert_str_dict(data: dict) -> dict: return data -def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]: +def get_plugin_config(config: PluginArgument) -> PluginConfig | None: """Get the plugin configuration from the argument value. Args: diff --git a/src/llamafactory/v1/config/data_args.py b/src/llamafactory/v1/config/data_args.py index 845bd1a82..c1bd5f23f 100644 --- a/src/llamafactory/v1/config/data_args.py +++ b/src/llamafactory/v1/config/data_args.py @@ -14,12 +14,11 @@ from dataclasses import dataclass, field -from typing import Optional @dataclass class DataArguments: - dataset: Optional[str] = field( + dataset: str | None = field( default=None, metadata={"help": "Path to the dataset."}, ) diff --git a/src/llamafactory/v1/config/model_args.py b/src/llamafactory/v1/config/model_args.py index 87d1a160c..370ed02d1 100644 --- a/src/llamafactory/v1/config/model_args.py +++ b/src/llamafactory/v1/config/model_args.py @@ -14,7 +14,6 @@ from dataclasses import dataclass, field -from typing import Optional from .arg_utils import ModelClass, PluginConfig, get_plugin_config @@ -36,15 +35,15 @@ class ModelArguments: default=ModelClass.LLM, metadata={"help": "Model class from Hugging Face."}, ) - peft_config: Optional[PluginConfig] = field( + peft_config: PluginConfig | None = field( default=None, metadata={"help": "PEFT configuration for the model."}, ) - kernel_config: Optional[PluginConfig] = field( + kernel_config: PluginConfig | None = field( default=None, metadata={"help": "Kernel configuration for the model."}, ) - quant_config: Optional[PluginConfig] = field( + quant_config: PluginConfig | None = field( default=None, metadata={"help": "Quantization configuration for the model."}, ) diff --git a/src/llamafactory/v1/config/training_args.py b/src/llamafactory/v1/config/training_args.py index d1afaf483..574ff015e 100644 --- a/src/llamafactory/v1/config/training_args.py +++ b/src/llamafactory/v1/config/training_args.py @@ -14,7 +14,6 @@ import os from dataclasses import dataclass, field -from typing import Optional from uuid import uuid4 from .arg_utils import PluginConfig, get_plugin_config @@ -42,7 +41,7 @@ class TrainingArguments: default=False, metadata={"help": "Use bf16 for training."}, ) - dist_config: Optional[PluginConfig] = field( + dist_config: PluginConfig | None = field( default=None, metadata={"help": "Distribution configuration for training."}, ) diff --git a/src/llamafactory/v1/core/data_engine.py b/src/llamafactory/v1/core/data_engine.py index 98d59d424..f0ebb00af 100644 --- a/src/llamafactory/v1/core/data_engine.py +++ b/src/llamafactory/v1/core/data_engine.py @@ -27,7 +27,7 @@ Get Data Sample: import os from collections.abc import Iterable -from typing import Any, Union +from typing import Any from huggingface_hub import hf_hub_download from omegaconf import OmegaConf @@ -134,7 +134,7 @@ class DataEngine(Dataset): else: return len(self.data_index) - def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]: + 
def __getitem__(self, index: int | Any) -> Sample | list[Sample]: """Get dataset item. Args: diff --git a/src/llamafactory/v1/plugins/data_plugins/converter.py b/src/llamafactory/v1/plugins/data_plugins/converter.py index b5778970d..07aae2dfe 100644 --- a/src/llamafactory/v1/plugins/data_plugins/converter.py +++ b/src/llamafactory/v1/plugins/data_plugins/converter.py @@ -13,9 +13,7 @@ # limitations under the License. -from typing import Any, Literal, TypedDict - -from typing_extensions import NotRequired +from typing import Any, Literal, NotRequired, TypedDict from ...utils import logging from ...utils.plugin import BasePlugin diff --git a/src/llamafactory/v1/plugins/data_plugins/loader.py b/src/llamafactory/v1/plugins/data_plugins/loader.py index 6329e3c33..9200ec34a 100644 --- a/src/llamafactory/v1/plugins/data_plugins/loader.py +++ b/src/llamafactory/v1/plugins/data_plugins/loader.py @@ -15,7 +15,7 @@ import os import random -from typing import Any, Literal, Optional, Union +from typing import Any, Literal from datasets import load_dataset @@ -70,7 +70,7 @@ class DataIndexPlugin(BasePlugin): """Plugin for adjusting dataset index.""" def adjust_data_index( - self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float] + self, data_index: list[tuple[str, int]], size: int | None, weight: float | None ) -> list[tuple[str, int]]: """Adjust dataset index by size and weight. @@ -95,8 +95,8 @@ class DataSelectorPlugin(BasePlugin): """Plugin for selecting dataset samples.""" def select( - self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any] - ) -> Union[tuple[str, int], list[tuple[str, int]]]: + self, data_index: list[tuple[str, int]], index: slice | list[int] | Any + ) -> tuple[str, int] | list[tuple[str, int]]: """Select dataset samples. Args: diff --git a/src/llamafactory/v1/plugins/data_plugins/template.py b/src/llamafactory/v1/plugins/data_plugins/template.py index 32ec6f378..96159142e 100644 --- a/src/llamafactory/v1/plugins/data_plugins/template.py +++ b/src/llamafactory/v1/plugins/data_plugins/template.py @@ -14,7 +14,6 @@ from dataclasses import dataclass -from typing import Union @dataclass @@ -32,7 +31,7 @@ class QwenTemplate: message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool thinking_template: str = "\n{content}\n\n\n" - def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str: + def _extract_content(self, content_data: str | list[dict[str, str]]) -> str: if isinstance(content_data, str): return content_data.strip() @@ -47,7 +46,7 @@ class QwenTemplate: return "" - def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str: + def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str: role = message["role"] content = self._extract_content(message.get("content", "")) diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index 08dc5b9da..78a235074 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -13,7 +13,8 @@ # limitations under the License. 
from abc import ABC, ABCMeta, abstractmethod -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional from ....accelerator.helper import DeviceType, get_current_accelerator from ....utils.types import HFModel @@ -38,7 +39,7 @@ class KernelRegistry: self._initialized = True def register( - self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]] + self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None ) -> None: """Register a kernel implementation. @@ -56,7 +57,7 @@ class KernelRegistry: self._registry[kernel_type][device_type] = kernel_impl print(f"Registered kernel {kernel_type.name} for device {device_type.name}.") - def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]: + def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None: return self._registry.get(kernel_type, {}).get(device_type) @@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta): auto_register: Set to False to disable automatic registration (default: True). """ - type: Optional[KernelType] = None - device: Optional[DeviceType] = None - kernel: Optional[Callable] = None + type: KernelType | None = None + device: DeviceType | None = None + kernel: Callable | None = None @classmethod @abstractmethod @@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]: return discovered_kernels -def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel": +def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel": """Call the MetaKernel's `apply` to perform the replacement. Corresponding replacement logic is maintained inside each kernel; the only diff --git a/src/llamafactory/v1/plugins/model_plugins/peft.py b/src/llamafactory/v1/plugins/model_plugins/peft.py index 0dc29ba88..819b06d14 100644 --- a/src/llamafactory/v1/plugins/model_plugins/peft.py +++ b/src/llamafactory/v1/plugins/model_plugins/peft.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Literal, Optional, TypedDict +from typing import Literal, TypedDict from peft import LoraConfig, PeftModel, get_peft_model @@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False): """Plugin name.""" freeze_trainable_layers: int """Freeze trainable layers.""" - freeze_trainable_modules: Optional[list[str]] + freeze_trainable_modules: list[str] | None """Freeze trainable modules.""" diff --git a/src/llamafactory/v1/utils/dtype.py b/src/llamafactory/v1/utils/dtype.py index 09dfb0112..f3f262007 100644 --- a/src/llamafactory/v1/utils/dtype.py +++ b/src/llamafactory/v1/utils/dtype.py @@ -16,7 +16,6 @@ # limitations under the License. 
from contextlib import contextmanager -from typing import Union import torch from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device @@ -38,7 +37,7 @@ class DtypeInterface: _is_fp32_available = True @staticmethod - def is_available(precision: Union[str, torch.dtype]) -> bool: + def is_available(precision: str | torch.dtype) -> bool: if precision in DtypeRegistry.HALF_LIST: return DtypeInterface._is_fp16_available elif precision in DtypeRegistry.FLOAT_LIST: @@ -49,19 +48,19 @@ class DtypeInterface: raise RuntimeError(f"Unexpected precision: {precision}") @staticmethod - def is_fp16(precision: Union[str, torch.dtype]) -> bool: + def is_fp16(precision: str | torch.dtype) -> bool: return precision in DtypeRegistry.HALF_LIST @staticmethod - def is_fp32(precision: Union[str, torch.dtype]) -> bool: + def is_fp32(precision: str | torch.dtype) -> bool: return precision in DtypeRegistry.FLOAT_LIST @staticmethod - def is_bf16(precision: Union[str, torch.dtype]) -> bool: + def is_bf16(precision: str | torch.dtype) -> bool: return precision in DtypeRegistry.BFLOAT_LIST @staticmethod - def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype: + def to_dtype(precision: str | torch.dtype) -> torch.dtype: if precision in DtypeRegistry.HALF_LIST: return torch.float16 elif precision in DtypeRegistry.FLOAT_LIST: @@ -83,7 +82,7 @@ class DtypeInterface: raise RuntimeError(f"Unexpected precision: {precision}") @contextmanager - def set_dtype(self, precision: Union[str, torch.dtype]): + def set_dtype(self, precision: str | torch.dtype): original_dtype = torch.get_default_dtype() torch.set_default_dtype(self.to_dtype(precision)) try: diff --git a/src/llamafactory/v1/utils/logging.py b/src/llamafactory/v1/utils/logging.py index ebb890986..81bc53751 100644 --- a/src/llamafactory/v1/utils/logging.py +++ b/src/llamafactory/v1/utils/logging.py @@ -81,7 +81,7 @@ def _configure_library_root_logger() -> None: library_root_logger.propagate = False -def get_logger(name: Optional[str] = None) -> "_Logger": +def get_logger(name: str | None = None) -> "_Logger": """Return a logger with the specified name. It it not supposed to be accessed externally.""" if name is None: name = _get_library_name() diff --git a/src/llamafactory/v1/utils/plugin.py b/src/llamafactory/v1/utils/plugin.py index 2c1138a87..0e4bcdf8e 100644 --- a/src/llamafactory/v1/utils/plugin.py +++ b/src/llamafactory/v1/utils/plugin.py @@ -13,7 +13,7 @@ # limitations under the License. -from typing import Callable, Optional +from collections.abc import Callable from . import logging @@ -29,7 +29,7 @@ class BasePlugin: _registry: dict[str, Callable] = {} - def __init__(self, name: Optional[str] = None): + def __init__(self, name: str | None = None): """Initialize the plugin with a name. Args: diff --git a/src/llamafactory/v1/utils/types.py b/src/llamafactory/v1/utils/types.py index 5d1899609..bae334f4a 100644 --- a/src/llamafactory/v1/utils/types.py +++ b/src/llamafactory/v1/utils/types.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TYPE_CHECKING, Literal, TypedDict, Union - -from typing_extensions import NotRequired +from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union if TYPE_CHECKING: diff --git a/src/llamafactory/webui/chatter.py b/src/llamafactory/webui/chatter.py index feeeaf3d8..e86505b4b 100644 --- a/src/llamafactory/webui/chatter.py +++ b/src/llamafactory/webui/chatter.py @@ -16,7 +16,7 @@ import json import os from collections.abc import Generator from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from transformers.utils import is_torch_npu_available @@ -81,7 +81,7 @@ class WebChatModel(ChatModel): def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None: self.manager = manager self.demo_mode = demo_mode - self.engine: Optional[BaseEngine] = None + self.engine: BaseEngine | None = None if not lazy_init: # read arguments from command line super().__init__() @@ -197,9 +197,9 @@ class WebChatModel(ChatModel): lang: str, system: str, tools: str, - image: Optional[Any], - video: Optional[Any], - audio: Optional[Any], + image: Any | None, + video: Any | None, + audio: Any | None, max_new_tokens: int, top_p: float, temperature: float, diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index a8e829f4e..cacf15182 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -17,7 +17,7 @@ import os import signal from collections import defaultdict from datetime import datetime -from typing import Any, Optional, Union +from typing import Any from psutil import Process from yaml import safe_dump, safe_load @@ -71,7 +71,7 @@ def _get_config_path() -> os.PathLike: return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) -def load_config() -> dict[str, Union[str, dict[str, Any]]]: +def load_config() -> dict[str, str | dict[str, Any]]: r"""Load user config if exists.""" try: with open(_get_config_path(), encoding="utf-8") as f: @@ -81,7 +81,7 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]: def save_config( - lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None + lang: str, hub_name: str | None = None, model_name: str | None = None, model_path: str | None = None ) -> None: r"""Save user config.""" os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) @@ -151,7 +151,7 @@ def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]: return {} -def load_args(config_path: str) -> Optional[dict[str, Any]]: +def load_args(config_path: str) -> dict[str, Any] | None: r"""Load the training configuration from config path.""" try: with open(config_path, encoding="utf-8") as f: diff --git a/src/llamafactory/webui/components/export.py b/src/llamafactory/webui/components/export.py index d153ffa6f..9597aa61b 100644 --- a/src/llamafactory/webui/components/export.py +++ b/src/llamafactory/webui/components/export.py @@ -14,7 +14,7 @@ import json from collections.abc import Generator -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from ...extras.constants import PEFT_METHODS from ...extras.misc import torch_gc @@ -37,7 +37,7 @@ if TYPE_CHECKING: GPTQ_BITS = ["8", "4", "3", "2"] -def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown": +def can_quantize(checkpoint_path: str | list[str]) -> "gr.Dropdown": if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0: return gr.Dropdown(value="none", interactive=False) else: 
@@ -49,7 +49,7 @@ def save_model( model_name: str, model_path: str, finetuning_type: str, - checkpoint_path: Union[str, list[str]], + checkpoint_path: str | list[str], template: str, export_size: int, export_quantization_bit: str, diff --git a/src/llamafactory/webui/control.py b/src/llamafactory/webui/control.py index f64b36994..ec99f4079 100644 --- a/src/llamafactory/webui/control.py +++ b/src/llamafactory/webui/control.py @@ -14,7 +14,7 @@ import json import os -from typing import Any, Optional +from typing import Any from transformers.trainer_utils import get_last_checkpoint @@ -206,7 +206,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S return gr.Dropdown(choices=datasets) -def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown": +def list_output_dirs(model_name: str | None, finetuning_type: str, current_time: str) -> "gr.Dropdown": r"""List all the directories that can resume from. Inputs: top.model_name, top.finetuning_type, train.current_time diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 0a6fc7c9a..9c772f5dc 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -17,7 +17,7 @@ import os from collections.abc import Generator from copy import deepcopy from subprocess import PIPE, Popen, TimeoutExpired -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any from transformers.utils import is_torch_npu_available @@ -59,7 +59,7 @@ class Runner: self.manager = manager self.demo_mode = demo_mode """ Resume """ - self.trainer: Optional[Popen] = None + self.trainer: Popen | None = None self.do_train = True self.running_data: dict[Component, Any] = None """ State """ diff --git a/tests/conftest.py b/tests/conftest.py index 71f7339f3..7220298fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers. """ import os -from typing import Optional import pytest from pytest import Config, FixtureRequest, Item, MonkeyPatch @@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]): item.add_marker(skip_slow) -def _get_visible_devices_env() -> Optional[str]: +def _get_visible_devices_env() -> str | None: """Return device visibility env var name.""" if CURRENT_DEVICE == "cuda": return "CUDA_VISIBLE_DEVICES" diff --git a/tests_v1/conftest.py b/tests_v1/conftest.py index 69d40fa5f..453a85e78 100644 --- a/tests_v1/conftest.py +++ b/tests_v1/conftest.py @@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers. """ import os -from typing import Optional import pytest from pytest import Config, FixtureRequest, Item, MonkeyPatch @@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]): item.add_marker(skip_slow) -def _get_visible_devices_env() -> Optional[str]: +def _get_visible_devices_env() -> str | None: """Return device visibility env var name.""" if CURRENT_DEVICE == "cuda": return "CUDA_VISIBLE_DEVICES"
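
Note (not part of the patch): the hunks above apply one mechanical modernization throughout the codebase: Optional[X] becomes X | None, Union[X, Y] becomes X | Y (PEP 604, available since Python 3.10), and Callable is imported from collections.abc rather than typing. A minimal standalone sketch of the before/after pattern under those assumptions; the ExampleConfig dataclass and get_handler function are illustrative names, not code from the repository:

# Illustrative sketch only -- not part of the patch. Shows the annotation
# style the patch adopts once Python 3.9 support is dropped: PEP 604 unions
# (X | None) and Callable taken from collections.abc instead of typing.
from collections.abc import Callable
from dataclasses import dataclass, field


@dataclass
class ExampleConfig:
    # formerly: dataset: Optional[str] = field(default=None, ...)
    dataset: str | None = field(default=None, metadata={"help": "Path to the dataset."})


def get_handler(name: str | None = None) -> Callable[..., int] | None:
    # formerly: def get_handler(name: Optional[str] = None) -> Optional[Callable[..., int]]
    handlers: dict[str, Callable[..., int]] = {"length": len}
    return handlers.get(name) if name is not None else None

Because PEP 604 syntax is evaluated at runtime in annotations, this only works once every supported interpreter is 3.10 or newer, which is exactly what dropping 3.9 from the test matrix and pyproject.toml guarantees.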
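Similarly, converter.py and types.py above drop the typing_extensions fallback and import NotRequired straight from typing, which requires Python 3.11 or newer. A hedged example of the resulting TypedDict usage; the MessageDict name and its keys are invented for illustration and do not appear in the repository:

# Illustrative sketch only. With Python 3.11+ interpreters, NotRequired lives
# in the standard typing module, so the typing_extensions fallback goes away.
from typing import Literal, NotRequired, TypedDict


class MessageDict(TypedDict):
    role: Literal["user", "assistant"]
    content: str
    tool_calls: NotRequired[list[dict[str, str]]]  # this key may be omitted


msg: MessageDict = {"role": "user", "content": "hello"}  # valid without tool_calls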