Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-29 18:20:35 +08:00)
[misc] lint code (#9395)
@@ -24,7 +24,7 @@ class KernelType(str, Enum):


 class DeviceType(str, Enum):
-    CPU = 'cpu'
-    CUDA = 'cuda'
-    NPU = 'npu'
-    XPU = 'xpu'
+    CPU = "cpu"
+    CUDA = "cuda"
+    NPU = "npu"
+    XPU = "xpu"
@@ -27,14 +27,11 @@ def _npu_swiglu_forward(self, hidden_state):
     import torch_npu

     return self.down_proj(
-        torch_npu.npu_swiglu(
-            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
-        )
+        torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
     )


 class NpuSwiGluKernel(MetaSwiGluKernel):
-
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward

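For context, the fused call above replaces the usual two-step SwiGLU computation. A minimal unfused sketch in plain PyTorch, assuming `torch_npu.npu_swiglu` splits its input in half along `dim` and computes `silu(first_half) * second_half` (which is what the gate/up concatenation implies):

import torch.nn.functional as F

def _unfused_swiglu_forward(self, hidden_state):
    # Reference path the NPU kernel stands in for: SwiGLU = silu(gate) * up,
    # followed by the down projection. torch_npu.npu_swiglu fuses the
    # activation and multiply into one op on the concatenated tensor.
    return self.down_proj(F.silu(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))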
@@ -43,7 +40,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)

     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model

@@ -51,7 +48,6 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
             if re.search(swiglu_pattern, module.__class__.__name__):
-
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
                 module.forward = types.MethodType(cls.kernel, module)
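The `types.MethodType` call above is the standard way to bind a free function to a single instance so that `self` resolves to that module while other instances keep their original `forward`. A standalone sketch (the `Toy` class is illustrative only):

import types

class Toy:
    def forward(self, x):
        return x

def patched_forward(self, x):
    # `self` is the instance the function was bound to
    return x * 2

toy = Toy()
toy.forward = types.MethodType(patched_forward, toy)  # instance-level override
assert toy.forward(3) == 6                            # other Toy instances are untouched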
@@ -21,10 +21,10 @@ from .constants import DeviceType, KernelType


 class KernelRegistry:
-    _instance: Optional['KernelRegistry'] = None
+    _instance: Optional["KernelRegistry"] = None
     _initialized: bool = False

-    def __new__(cls, *args: Any, **kwargs: Any) -> 'KernelRegistry':
+    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
         if cls._instance is None:
             cls._instance = super().__new__(cls)
         return cls._instance
@@ -36,10 +36,7 @@ class KernelRegistry:
         self._initialized = True

     def register(
-        self,
-        kernel_type: KernelType,
-        device_type: DeviceType,
-        kernel_impl: Optional[Callable[..., Any]]
+        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
     ) -> None:
         """Register a kernel implementation.

@@ -57,11 +54,7 @@ class KernelRegistry:
         self._registry[kernel_type][device_type] = kernel_impl
         print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")

-    def get_kernel(
-        self,
-        kernel_type: KernelType,
-        device_type: DeviceType
-    ) -> Optional[Callable[..., Any]]:
+    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
         return self._registry.get(kernel_type, {}).get(device_type)

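`register` and `get_kernel` together give a two-level lookup keyed by kernel type and device type. A hedged usage sketch, assuming `KERNEL_REGISTRY` is the module-level `KernelRegistry()` instance referenced by the register calls elsewhere in this diff:

def resolve_and_apply(model, kernel_type, device_type):
    # Look up whatever implementation (here: a MetaKernel subclass) was
    # registered for this (kernel_type, device_type) pair; fall back to the
    # unmodified model when nothing is registered.
    impl = KERNEL_REGISTRY.get_kernel(kernel_type, device_type)
    if impl is None:
        return model
    return impl.apply(model)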
@@ -84,35 +77,30 @@ class MetaKernel(ABC):


 class MetaFlashAttentionKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError


 class MetaRMSNormKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError


 class MetaSwiGluKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError


 class MetaRoPEKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError


 class MetaMoEKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
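Each `Meta*Kernel` above is an abstract per-operator hook; a concrete backend subclasses it, sets `device` and `kernel`, and overrides `apply`, following the `NpuSwiGluKernel` shape earlier in this diff. A hypothetical CUDA variant as a sketch (`_cuda_swiglu_forward` and the `MLP` name match are illustrative, not part of this commit):

import types

class CudaSwiGluKernel(MetaSwiGluKernel):
    device = DeviceType.CUDA
    kernel = _cuda_swiglu_forward  # hypothetical fused forward for CUDA

    @classmethod
    def apply(cls, model, **kwargs) -> "HFModel":
        if not is_torch_cuda_available():
            return model
        for _, module in model.named_modules():
            if module.__class__.__name__.endswith("MLP"):
                # Same instance-level rebinding pattern as the NPU kernels
                module.forward = types.MethodType(cls.kernel, module)
        return model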
@@ -130,7 +118,7 @@ def discover_kernels(model: HFModel) -> list[MetaKernel]:
     return []


-def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFModel':
+def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
     """Call the MetaKernel's `apply` to perform the replacement.

     Corresponding replacement logic is maintained inside each kernel; the only
@@ -145,4 +133,6 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFMo
     if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
         return kernel.apply(model, **kwargs)

-    raise ValueError(f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead.")
+    raise ValueError(
+        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
+    )
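`apply_kernel` is the narrow entry point: it only patches the model when the kernel's `device` matches the current accelerator, and raises otherwise. A hedged caller sketch that pre-filters instead of relying on the exception:

def optimize(model):
    # Try every kernel proposed by discover_kernels (currently a stub that
    # returns []); skip kernels built for a different accelerator.
    for kernel_cls in discover_kernels(model):
        if kernel_cls.device == get_available_accelerator().type:
            model = apply_kernel(model, kernel_cls)
    return model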
@@ -65,7 +65,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
             if re.search(rms_norm_pattern, module.__class__.__name__):
-
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
                 module.forward = types.MethodType(cls.kernel, module)
@@ -59,7 +59,7 @@ class NpuRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)

     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

         This function iterates through the model's modules to find attention layers,
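Unlike the forward-rebinding kernels above, the RoPE patch described in this docstring targets a module-level function, so the replacement has to land in the namespace of the transformers modeling file that defines the attention class. A sketch of that general technique (the `replacement` argument stands in for the NPU rotary implementation; names are illustrative, not this kernel's actual code):

import sys

def patch_rotary_pos_emb(model, replacement):
    # Overwrite apply_rotary_pos_emb in every modeling module that defines an
    # attention class, so existing call sites pick up the replacement.
    for _, module in model.named_modules():
        if "Attention" in module.__class__.__name__:
            modeling_module = sys.modules[module.__class__.__module__]
            if hasattr(modeling_module, "apply_rotary_pos_emb"):
                modeling_module.apply_rotary_pos_emb = replacement
    return model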
@@ -96,7 +96,7 @@ class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)

     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.

         This function iterates through the model's modules to find attention layers,
@@ -23,25 +23,25 @@ def get_available_accelerator():
     """
     accelerator = torch.accelerator.current_accelerator()
     if accelerator is None:
-        return torch.device('cpu')
+        return torch.device("cpu")
     return accelerator


 @lru_cache
 def is_torch_npu_available():
-    return get_available_accelerator().type == 'npu'
+    return get_available_accelerator().type == "npu"


 @lru_cache
 def is_torch_cuda_available():
-    return get_available_accelerator().type == 'cuda'
+    return get_available_accelerator().type == "cuda"


 @lru_cache
 def is_torch_xpu_available():
-    return get_available_accelerator().type == 'xpu'
+    return get_available_accelerator().type == "xpu"


 @lru_cache
 def is_torch_mps_available():
-    return get_available_accelerator().type == 'mps'
+    return get_available_accelerator().type == "mps"
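These helpers wrap `torch.accelerator.current_accelerator()` (available in recent PyTorch releases) behind `lru_cache`, so repeated device checks cost a single lookup per process. A small caller sketch (the backend names are illustrative):

def pick_attention_backend():
    # Each is_torch_*_available() call below is cached after the first hit.
    if is_torch_npu_available():
        return "npu_fusion_attention"
    if is_torch_cuda_available():
        return "flash_attention_2"
    return "eager"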