Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 18:20:35 +08:00)
[deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
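Most of the diff below is a mechanical cleanup enabled by dropping Python 3.9: type hints switch from `typing.Optional[X]` / `typing.Union[A, B]` to the PEP 604 union syntax (`X | None`, `A | B`), and `Callable` is imported from `collections.abc` instead of `typing`. A minimal sketch of the before/after pattern the commit applies (the function below is made up for this note, it is not repository code):

# Before (pre-3.10 style):
#   from typing import Callable, Optional, Union
#   def resolve(name: Optional[str], factory: Optional[Callable[[str], int]] = None) -> Union[int, str]: ...

# After (Python 3.10+ style, same pattern as the hunks below):
from collections.abc import Callable


def resolve(name: str | None, factory: Callable[[str], int] | None = None) -> int | str:
    """Return factory(name) when both are provided, otherwise echo the name."""
    if name is not None and factory is not None:
        return factory(name)
    return name or ""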
@@ -25,10 +25,11 @@ Including:
 """

 import os
+from collections.abc import Callable
 from contextlib import contextmanager
 from enum import Enum, unique
 from functools import lru_cache, wraps
-from typing import Callable, Optional
+from typing import Optional

 import numpy as np
 import torch
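A side note on the hunk above: `collections.abc.Callable` has been subscriptable since Python 3.9 (PEP 585), and the `typing.Callable` alias is deprecated in its favor, which is why the import moves. A minimal sketch of the subscripted form, with illustrative names that are not from the repository:

from collections.abc import Callable


def apply_twice(fn: Callable[[int], int], x: int) -> int:
    # Callable[[int], int] describes a function taking one int and returning an int.
    return fn(fn(x))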
@@ -53,9 +53,9 @@ class DistributedStrategy:
     mp_replicate_size: int = 1
     """Model parallel replicate size, default to 1."""
-    mp_shard_size: Optional[int] = None
+    mp_shard_size: int | None = None
     """Model parallel shard size, default to world_size // mp_replicate_size."""
-    dp_size: Optional[int] = None
+    dp_size: int | None = None
     """Data parallel size, default to world_size // cp_size."""
     cp_size: int = 1
     """Context parallel size, default to 1."""
@@ -115,7 +115,7 @@ class DistributedInterface:
         return cls._instance

-    def __init__(self, config: Optional[DistributedConfig] = None) -> None:
+    def __init__(self, config: DistributedConfig | None = None) -> None:
         if self._initialized:
             return

@@ -165,7 +165,7 @@ class DistributedInterface:
             f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
         )

-    def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
+    def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
         """Get device mesh for specified dimension."""
         if dim is None:
             raise ValueError("dim must be specified.")
@@ -176,14 +176,14 @@ class DistributedInterface:
         else:
             return self.model_device_mesh[dim.value]

-    def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
+    def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
         if self.model_device_mesh is None or dim is None:
             return None
         else:
             return self.get_device_mesh(dim).get_group()

-    def get_rank(self, dim: Optional[Dim] = None) -> int:
+    def get_rank(self, dim: Dim | None = None) -> int:
         """Get parallel rank for specified dimension."""
         if self.model_device_mesh is None:
             return 0
@@ -192,7 +192,7 @@ class DistributedInterface:
         else:
             return self.get_device_mesh(dim).get_local_rank()

-    def get_world_size(self, dim: Optional[Dim] = None) -> int:
+    def get_world_size(self, dim: Dim | None = None) -> int:
         """Get parallel size for specified dimension."""
         if self.model_device_mesh is None:
             return 1
@@ -209,7 +209,7 @@ class DistributedInterface:
         """Get parallel local world size."""
         return self._local_world_size

-    def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
+    def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
         """Gather tensor across specified parallel group."""
         if self.model_device_mesh is not None:
             return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
@@ -217,7 +217,7 @@ class DistributedInterface:
             return data

     def all_reduce(
-        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
+        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
     ) -> TensorLike:
         """Reduce tensor across specified parallel group."""
         if self.model_device_mesh is not None:
@@ -225,7 +225,7 @@ class DistributedInterface:
         else:
             return data

-    def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
+    def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
         """Broadcast tensor across specified parallel group."""
         if self.model_device_mesh is not None:
             return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
@@ -15,7 +15,7 @@
 import json
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any

 from omegaconf import OmegaConf
 from transformers import HfArgumentParser
@@ -27,7 +27,7 @@ from .sample_args import SampleArguments
 from .training_args import TrainingArguments


-InputArgument = Optional[Union[dict[str, Any], list[str]]]
+InputArgument = dict[str, Any] | list[str] | None


 def validate_args(
@@ -18,7 +18,6 @@

 import json
 from enum import Enum, unique
-from typing import Optional, Union


 class PluginConfig(dict):
@@ -33,7 +32,7 @@ class PluginConfig(dict):
         return self["name"]


-PluginArgument = Optional[Union[PluginConfig, dict, str]]
+PluginArgument = PluginConfig | dict | str | None


 @unique
@@ -74,7 +73,7 @@ def _convert_str_dict(data: dict) -> dict:
     return data


-def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
+def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
     """Get the plugin configuration from the argument value.

     Args:
@@ -14,12 +14,11 @@


 from dataclasses import dataclass, field
-from typing import Optional


 @dataclass
 class DataArguments:
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
         default=None,
         metadata={"help": "Path to the dataset."},
     )
@@ -14,7 +14,6 @@


 from dataclasses import dataclass, field
-from typing import Optional

 from .arg_utils import ModelClass, PluginConfig, get_plugin_config

@@ -36,15 +35,15 @@ class ModelArguments:
         default=ModelClass.LLM,
         metadata={"help": "Model class from Hugging Face."},
     )
-    peft_config: Optional[PluginConfig] = field(
+    peft_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "PEFT configuration for the model."},
     )
-    kernel_config: Optional[PluginConfig] = field(
+    kernel_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Kernel configuration for the model."},
     )
-    quant_config: Optional[PluginConfig] = field(
+    quant_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Quantization configuration for the model."},
     )
@@ -14,7 +14,6 @@

 import os
 from dataclasses import dataclass, field
-from typing import Optional
 from uuid import uuid4

 from .arg_utils import PluginConfig, get_plugin_config
@@ -42,7 +41,7 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Use bf16 for training."},
     )
-    dist_config: Optional[PluginConfig] = field(
+    dist_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Distribution configuration for training."},
     )
@@ -27,7 +27,7 @@ Get Data Sample:

 import os
 from collections.abc import Iterable
-from typing import Any, Union
+from typing import Any

 from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
@@ -134,7 +134,7 @@ class DataEngine(Dataset):
         else:
             return len(self.data_index)

-    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
+    def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
         """Get dataset item.

         Args:
@@ -13,9 +13,7 @@
 # limitations under the License.


-from typing import Any, Literal, TypedDict
-
-from typing_extensions import NotRequired
+from typing import Any, Literal, NotRequired, TypedDict

 from ...utils import logging
 from ...utils.plugin import BasePlugin
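This hunk (and a matching one near the end of the diff) folds the `typing_extensions` backport of `NotRequired` into the standard `typing` import; `NotRequired` is part of `typing` itself from Python 3.11 on (PEP 655). A minimal sketch of what it marks in a TypedDict, using made-up field names rather than repository code:

from typing import NotRequired, TypedDict


class MessageDict(TypedDict):
    role: str                        # required key
    content: str                     # required key
    tool_call_id: NotRequired[str]   # this key may be omitted entirely


msg: MessageDict = {"role": "user", "content": "hello"}  # valid without tool_call_id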
@@ -15,7 +15,7 @@

 import os
 import random
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal

 from datasets import load_dataset

@@ -70,7 +70,7 @@ class DataIndexPlugin(BasePlugin):
     """Plugin for adjusting dataset index."""

     def adjust_data_index(
-        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
+        self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
     ) -> list[tuple[str, int]]:
         """Adjust dataset index by size and weight.

@@ -95,8 +95,8 @@ class DataSelectorPlugin(BasePlugin):
     """Plugin for selecting dataset samples."""

     def select(
-        self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
-    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
+        self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
+    ) -> tuple[str, int] | list[tuple[str, int]]:
         """Select dataset samples.

         Args:
@@ -14,7 +14,6 @@


 from dataclasses import dataclass
-from typing import Union


 @dataclass
@@ -32,7 +31,7 @@ class QwenTemplate:
     message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n" # FIXME if role: tool
     thinking_template: str = "<think>\n{content}\n</think>\n\n"

-    def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
+    def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
         if isinstance(content_data, str):
             return content_data.strip()

@@ -47,7 +46,7 @@ class QwenTemplate:

         return ""

-    def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
+    def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
         role = message["role"]
         content = self._extract_content(message.get("content", ""))

@@ -13,7 +13,8 @@
 # limitations under the License.

 from abc import ABC, ABCMeta, abstractmethod
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any, Optional

 from ....accelerator.helper import DeviceType, get_current_accelerator
 from ....utils.types import HFModel
@@ -38,7 +39,7 @@ class KernelRegistry:
         self._initialized = True

     def register(
-        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
+        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
     ) -> None:
         """Register a kernel implementation.

@@ -56,7 +57,7 @@ class KernelRegistry:
         self._registry[kernel_type][device_type] = kernel_impl
         print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

-    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
+    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
         return self._registry.get(kernel_type, {}).get(device_type)


@@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
         auto_register: Set to False to disable automatic registration (default: True).
     """

-    type: Optional[KernelType] = None
-    device: Optional[DeviceType] = None
-    kernel: Optional[Callable] = None
+    type: KernelType | None = None
+    device: DeviceType | None = None
+    kernel: Callable | None = None

     @classmethod
     @abstractmethod
@@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     return discovered_kernels


-def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
+def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
     """Call the MetaKernel's `apply` to perform the replacement.

     Corresponding replacement logic is maintained inside each kernel; the only
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Literal, Optional, TypedDict
+from typing import Literal, TypedDict

 from peft import LoraConfig, PeftModel, get_peft_model

@@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False):
     """Plugin name."""
     freeze_trainable_layers: int
     """Freeze trainable layers."""
-    freeze_trainable_modules: Optional[list[str]]
+    freeze_trainable_modules: list[str] | None
     """Freeze trainable modules."""


@@ -16,7 +16,6 @@
 # limitations under the License.

 from contextlib import contextmanager
-from typing import Union

 import torch
 from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
@@ -38,7 +37,7 @@ class DtypeInterface:
     _is_fp32_available = True

     @staticmethod
-    def is_available(precision: Union[str, torch.dtype]) -> bool:
+    def is_available(precision: str | torch.dtype) -> bool:
         if precision in DtypeRegistry.HALF_LIST:
             return DtypeInterface._is_fp16_available
         elif precision in DtypeRegistry.FLOAT_LIST:
@@ -49,19 +48,19 @@ class DtypeInterface:
             raise RuntimeError(f"Unexpected precision: {precision}")

     @staticmethod
-    def is_fp16(precision: Union[str, torch.dtype]) -> bool:
+    def is_fp16(precision: str | torch.dtype) -> bool:
         return precision in DtypeRegistry.HALF_LIST

     @staticmethod
-    def is_fp32(precision: Union[str, torch.dtype]) -> bool:
+    def is_fp32(precision: str | torch.dtype) -> bool:
         return precision in DtypeRegistry.FLOAT_LIST

     @staticmethod
-    def is_bf16(precision: Union[str, torch.dtype]) -> bool:
+    def is_bf16(precision: str | torch.dtype) -> bool:
         return precision in DtypeRegistry.BFLOAT_LIST

     @staticmethod
-    def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
+    def to_dtype(precision: str | torch.dtype) -> torch.dtype:
         if precision in DtypeRegistry.HALF_LIST:
             return torch.float16
         elif precision in DtypeRegistry.FLOAT_LIST:
@@ -83,7 +82,7 @@ class DtypeInterface:
             raise RuntimeError(f"Unexpected precision: {precision}")

     @contextmanager
-    def set_dtype(self, precision: Union[str, torch.dtype]):
+    def set_dtype(self, precision: str | torch.dtype):
         original_dtype = torch.get_default_dtype()
         torch.set_default_dtype(self.to_dtype(precision))
         try:
@@ -81,7 +81,7 @@ def _configure_library_root_logger() -> None:
     library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "_Logger":
+def get_logger(name: str | None = None) -> "_Logger":
     """Return a logger with the specified name. It it not supposed to be accessed externally."""
     if name is None:
         name = _get_library_name()
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Callable, Optional
+from collections.abc import Callable

 from . import logging

@@ -29,7 +29,7 @@ class BasePlugin:

     _registry: dict[str, Callable] = {}

-    def __init__(self, name: Optional[str] = None):
+    def __init__(self, name: str | None = None):
         """Initialize the plugin with a name.

         Args:
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING, Literal, TypedDict, Union
-
-from typing_extensions import NotRequired
+from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union


 if TYPE_CHECKING: