[deps] goodbye python 3.9 (#9677)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
Copilot
2025-12-27 02:50:44 +08:00
committed by GitHub
parent b44f651e09
commit eceec8ab69
48 changed files with 267 additions and 284 deletions

View File

@@ -25,10 +25,11 @@ Including:
"""
import os
from collections.abc import Callable
from contextlib import contextmanager
from enum import Enum, unique
from functools import lru_cache, wraps
from typing import Callable, Optional
from typing import Optional
import numpy as np
import torch

View File

@@ -53,9 +53,9 @@ class DistributedStrategy:
mp_replicate_size: int = 1
"""Model parallel replicate size, default to 1."""
mp_shard_size: Optional[int] = None
mp_shard_size: int | None = None
"""Model parallel shard size, default to world_size // mp_replicate_size."""
dp_size: Optional[int] = None
dp_size: int | None = None
"""Data parallel size, default to world_size // cp_size."""
cp_size: int = 1
"""Context parallel size, default to 1."""
@@ -115,7 +115,7 @@ class DistributedInterface:
return cls._instance
def __init__(self, config: Optional[DistributedConfig] = None) -> None:
def __init__(self, config: DistributedConfig | None = None) -> None:
if self._initialized:
return
@@ -165,7 +165,7 @@ class DistributedInterface:
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
)
def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
"""Get device mesh for specified dimension."""
if dim is None:
raise ValueError("dim must be specified.")
@@ -176,14 +176,14 @@ class DistributedInterface:
else:
return self.model_device_mesh[dim.value]
def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
def get_group(self, dim: Dim | None = None) -> ProcessGroup | None:
    """Get process group for specified dimension.

    Args:
        dim: Parallel dimension whose process group is requested.

    Returns:
        ``None`` when no model device mesh is set or ``dim`` is ``None``;
        otherwise the group returned by the device mesh of ``dim``.
    """
    # Without a mesh (or without a dimension) there is no group to look up.
    if self.model_device_mesh is None or dim is None:
        return None

    return self.get_device_mesh(dim).get_group()
def get_rank(self, dim: Optional[Dim] = None) -> int:
def get_rank(self, dim: Dim | None = None) -> int:
"""Get parallel rank for specified dimension."""
if self.model_device_mesh is None:
return 0
@@ -192,7 +192,7 @@ class DistributedInterface:
else:
return self.get_device_mesh(dim).get_local_rank()
def get_world_size(self, dim: Optional[Dim] = None) -> int:
def get_world_size(self, dim: Dim | None = None) -> int:
"""Get parallel size for specified dimension."""
if self.model_device_mesh is None:
return 1
@@ -209,7 +209,7 @@ class DistributedInterface:
"""Get parallel local world size."""
return self._local_world_size
def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
"""Gather tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
@@ -217,7 +217,7 @@ class DistributedInterface:
return data
def all_reduce(
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
) -> TensorLike:
"""Reduce tensor across specified parallel group."""
if self.model_device_mesh is not None:
@@ -225,7 +225,7 @@ class DistributedInterface:
else:
return data
def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
"""Broadcast tensor across specified parallel group."""
if self.model_device_mesh is not None:
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))

View File

@@ -15,7 +15,7 @@
import json
import sys
from pathlib import Path
from typing import Any, Optional, Union
from typing import Any
from omegaconf import OmegaConf
from transformers import HfArgumentParser
@@ -27,7 +27,7 @@ from .sample_args import SampleArguments
from .training_args import TrainingArguments
InputArgument = Optional[Union[dict[str, Any], list[str]]]
InputArgument = dict[str, Any] | list[str] | None
def validate_args(

View File

@@ -18,7 +18,6 @@
import json
from enum import Enum, unique
from typing import Optional, Union
class PluginConfig(dict):
@@ -33,7 +32,7 @@ class PluginConfig(dict):
return self["name"]
PluginArgument = Optional[Union[PluginConfig, dict, str]]
PluginArgument = PluginConfig | dict | str | None
@unique
@@ -74,7 +73,7 @@ def _convert_str_dict(data: dict) -> dict:
return data
def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
"""Get the plugin configuration from the argument value.
Args:

View File

@@ -14,12 +14,11 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class DataArguments:
dataset: Optional[str] = field(
dataset: str | None = field(
default=None,
metadata={"help": "Path to the dataset."},
)

View File

@@ -14,7 +14,6 @@
from dataclasses import dataclass, field
from typing import Optional
from .arg_utils import ModelClass, PluginConfig, get_plugin_config
@@ -36,15 +35,15 @@ class ModelArguments:
default=ModelClass.LLM,
metadata={"help": "Model class from Hugging Face."},
)
peft_config: Optional[PluginConfig] = field(
peft_config: PluginConfig | None = field(
default=None,
metadata={"help": "PEFT configuration for the model."},
)
kernel_config: Optional[PluginConfig] = field(
kernel_config: PluginConfig | None = field(
default=None,
metadata={"help": "Kernel configuration for the model."},
)
quant_config: Optional[PluginConfig] = field(
quant_config: PluginConfig | None = field(
default=None,
metadata={"help": "Quantization configuration for the model."},
)

View File

@@ -14,7 +14,6 @@
import os
from dataclasses import dataclass, field
from typing import Optional
from uuid import uuid4
from .arg_utils import PluginConfig, get_plugin_config
@@ -42,7 +41,7 @@ class TrainingArguments:
default=False,
metadata={"help": "Use bf16 for training."},
)
dist_config: Optional[PluginConfig] = field(
dist_config: PluginConfig | None = field(
default=None,
metadata={"help": "Distribution configuration for training."},
)

View File

@@ -27,7 +27,7 @@ Get Data Sample:
import os
from collections.abc import Iterable
from typing import Any, Union
from typing import Any
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
@@ -134,7 +134,7 @@ class DataEngine(Dataset):
else:
return len(self.data_index)
def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
"""Get dataset item.
Args:

View File

@@ -13,9 +13,7 @@
# limitations under the License.
from typing import Any, Literal, TypedDict
from typing_extensions import NotRequired
from typing import Any, Literal, NotRequired, TypedDict
from ...utils import logging
from ...utils.plugin import BasePlugin

View File

@@ -15,7 +15,7 @@
import os
import random
from typing import Any, Literal, Optional, Union
from typing import Any, Literal
from datasets import load_dataset
@@ -70,7 +70,7 @@ class DataIndexPlugin(BasePlugin):
"""Plugin for adjusting dataset index."""
def adjust_data_index(
self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
) -> list[tuple[str, int]]:
"""Adjust dataset index by size and weight.
@@ -95,8 +95,8 @@ class DataSelectorPlugin(BasePlugin):
"""Plugin for selecting dataset samples."""
def select(
self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
) -> Union[tuple[str, int], list[tuple[str, int]]]:
self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
) -> tuple[str, int] | list[tuple[str, int]]:
"""Select dataset samples.
Args:

View File

@@ -14,7 +14,6 @@
from dataclasses import dataclass
from typing import Union
@dataclass
@@ -32,7 +31,7 @@ class QwenTemplate:
message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
thinking_template: str = "<think>\n{content}\n</think>\n\n"
def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
if isinstance(content_data, str):
return content_data.strip()
@@ -47,7 +46,7 @@ class QwenTemplate:
return ""
def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
role = message["role"]
content = self._extract_content(message.get("content", ""))

View File

@@ -13,7 +13,8 @@
# limitations under the License.
from abc import ABC, ABCMeta, abstractmethod
from typing import Any, Callable, Optional, Union
from collections.abc import Callable
from typing import Any, Optional
from ....accelerator.helper import DeviceType, get_current_accelerator
from ....utils.types import HFModel
@@ -38,7 +39,7 @@ class KernelRegistry:
self._initialized = True
def register(
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
) -> None:
"""Register a kernel implementation.
@@ -56,7 +57,7 @@ class KernelRegistry:
self._registry[kernel_type][device_type] = kernel_impl
print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
return self._registry.get(kernel_type, {}).get(device_type)
@@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
auto_register: Set to False to disable automatic registration (default: True).
"""
type: Optional[KernelType] = None
device: Optional[DeviceType] = None
kernel: Optional[Callable] = None
type: KernelType | None = None
device: DeviceType | None = None
kernel: Callable | None = None
@classmethod
@abstractmethod
@@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
return discovered_kernels
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
"""Call the MetaKernel's `apply` to perform the replacement.
Corresponding replacement logic is maintained inside each kernel; the only

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, Optional, TypedDict
from typing import Literal, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
@@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False):
"""Plugin name."""
freeze_trainable_layers: int
"""Freeze trainable layers."""
freeze_trainable_modules: Optional[list[str]]
freeze_trainable_modules: list[str] | None
"""Freeze trainable modules."""

View File

@@ -16,7 +16,6 @@
# limitations under the License.
from contextlib import contextmanager
from typing import Union
import torch
from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
@@ -38,7 +37,7 @@ class DtypeInterface:
_is_fp32_available = True
@staticmethod
def is_available(precision: Union[str, torch.dtype]) -> bool:
def is_available(precision: str | torch.dtype) -> bool:
if precision in DtypeRegistry.HALF_LIST:
return DtypeInterface._is_fp16_available
elif precision in DtypeRegistry.FLOAT_LIST:
@@ -49,19 +48,19 @@ class DtypeInterface:
raise RuntimeError(f"Unexpected precision: {precision}")
@staticmethod
def is_fp16(precision: Union[str, torch.dtype]) -> bool:
def is_fp16(precision: str | torch.dtype) -> bool:
return precision in DtypeRegistry.HALF_LIST
@staticmethod
def is_fp32(precision: Union[str, torch.dtype]) -> bool:
def is_fp32(precision: str | torch.dtype) -> bool:
return precision in DtypeRegistry.FLOAT_LIST
@staticmethod
def is_bf16(precision: Union[str, torch.dtype]) -> bool:
def is_bf16(precision: str | torch.dtype) -> bool:
return precision in DtypeRegistry.BFLOAT_LIST
@staticmethod
def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
def to_dtype(precision: str | torch.dtype) -> torch.dtype:
if precision in DtypeRegistry.HALF_LIST:
return torch.float16
elif precision in DtypeRegistry.FLOAT_LIST:
@@ -83,7 +82,7 @@ class DtypeInterface:
raise RuntimeError(f"Unexpected precision: {precision}")
@contextmanager
def set_dtype(self, precision: Union[str, torch.dtype]):
def set_dtype(self, precision: str | torch.dtype):
original_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.to_dtype(precision))
try:

View File

@@ -81,7 +81,7 @@ def _configure_library_root_logger() -> None:
library_root_logger.propagate = False
def get_logger(name: Optional[str] = None) -> "_Logger":
def get_logger(name: str | None = None) -> "_Logger":
"""Return a logger with the specified name. It is not supposed to be accessed externally."""
if name is None:
name = _get_library_name()

View File

@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Callable, Optional
from collections.abc import Callable
from . import logging
@@ -29,7 +29,7 @@ class BasePlugin:
_registry: dict[str, Callable] = {}
def __init__(self, name: Optional[str] = None):
def __init__(self, name: str | None = None):
"""Initialize the plugin with a name.
Args:

View File

@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Literal, TypedDict, Union
from typing_extensions import NotRequired
from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
if TYPE_CHECKING: