mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-08 04:35:58 +08:00
[deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com> Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
This commit is contained in:
@@ -13,7 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, ABCMeta, abstractmethod
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
from ....accelerator.helper import DeviceType, get_current_accelerator
|
||||
from ....utils.types import HFModel
|
||||
@@ -38,7 +39,7 @@ class KernelRegistry:
|
||||
self._initialized = True
|
||||
|
||||
def register(
|
||||
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
|
||||
self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
|
||||
) -> None:
|
||||
"""Register a kernel implementation.
|
||||
|
||||
@@ -56,7 +57,7 @@ class KernelRegistry:
|
||||
self._registry[kernel_type][device_type] = kernel_impl
|
||||
print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
|
||||
|
||||
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
|
||||
def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
|
||||
return self._registry.get(kernel_type, {}).get(device_type)
|
||||
|
||||
|
||||
@@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
|
||||
auto_register: Set to False to disable automatic registration (default: True).
|
||||
"""
|
||||
|
||||
type: Optional[KernelType] = None
|
||||
device: Optional[DeviceType] = None
|
||||
kernel: Optional[Callable] = None
|
||||
type: KernelType | None = None
|
||||
device: DeviceType | None = None
|
||||
kernel: Callable | None = None
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
@@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
|
||||
return discovered_kernels
|
||||
|
||||
|
||||
def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
|
||||
def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
|
||||
"""Call the MetaKernel's `apply` to perform the replacement.
|
||||
|
||||
Corresponding replacement logic is maintained inside each kernel; the only
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Literal, Optional, TypedDict
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
@@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False):
|
||||
"""Plugin name."""
|
||||
freeze_trainable_layers: int
|
||||
"""Freeze trainable layers."""
|
||||
freeze_trainable_modules: Optional[list[str]]
|
||||
freeze_trainable_modules: list[str] | None
|
||||
"""Freeze trainable modules."""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user