Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-05 02:12:14 +08:00)

[misc] lint code (#9395)

This commit is contained in:
  parent 215580c77d
  commit 3ae15da9c0
@@ -137,7 +137,6 @@ def _load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             num_proc=data_args.preprocessing_num_workers,
-            trust_remote_code=model_args.trust_remote_code,
             streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
         if data_args.streaming and dataset_attr.load_from == "file":

@@ -70,7 +70,6 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
     from transformers.video_processing_utils import BaseVideoProcessor
 
-
     class EncodedImage(TypedDict):
         path: Optional[str]
         bytes: Optional[bytes]

@@ -56,7 +56,18 @@ LAYERNORM_NAMES = {"norm", "ln"}
 
 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
 
-MCA_SUPPORTED_MODELS = {"deepseek_v3", "llama", "mistral", "mixtral", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "qwen3_next"}
+MCA_SUPPORTED_MODELS = {
+    "deepseek_v3",
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_vl",
+    "qwen2_5_vl",
+    "qwen3",
+    "qwen3_moe",
+    "qwen3_next",
+}
 
 METHODS = ["full", "freeze", "lora", "oft"]
 

@@ -475,7 +475,12 @@ class FinetuningArguments(
     )
     use_mca: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use MCA (Megatron Core Adapter) training. Controlled by USE_MCA environment variable."},
+        metadata={
+            "help": (
+                "Whether or not to use MCA (Megatron Core Adapter) training. "
+                "Controlled by USE_MCA environment variable."
+            )
+        },
     )
     use_muon: bool = field(
         default=False,

@@ -55,12 +55,16 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
 
 if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
     from mcore_adapter import TrainingArguments as McaTrainingArguments
+
     _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
-    _TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[
+        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
+    ]
 else:
     _TRAIN_MCA_ARGS = []
     _TRAIN_MCA_CLS = tuple()
 
+
 def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
     r"""Get arguments from the command line or a config file."""
     if args is not None:

@@ -20,17 +20,18 @@ from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
 
 from ..extras.misc import is_env_enabled, use_ray
+from ..extras.packages import is_mcore_adapter_available
 
 
 if is_env_enabled("USE_MCA"):
-    try:
-        from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
-        BaseTrainingArguments = McaSeq2SeqTrainingArguments
-    except ImportError:
+    if not is_mcore_adapter_available():
         raise ImportError(
-            "mcore_adapter is required when USE_MCA=1.",
-            "Please install `mcore_adapter` and its dependencies."
+            "mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
         )
+
+    from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+
+    BaseTrainingArguments = McaSeq2SeqTrainingArguments
 else:
     BaseTrainingArguments = Seq2SeqTrainingArguments
 

@@ -54,8 +54,7 @@ def launch():
     )
 
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if is_env_enabled("USE_MCA"):
-        # force use torchrun
+    if is_env_enabled("USE_MCA"):  # force use torchrun
         os.environ["FORCE_TORCHRUN"] = "1"
 
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):

@@ -16,4 +16,3 @@ from .workflow import run_dpo, run_pt, run_sft
 
-
 __all__ = ["run_dpo", "run_pt", "run_sft"]
 

@@ -75,12 +75,17 @@ def _data_collator_wrapper(data_collator: Any):
 
     return wrapper
 
+
 def _check_model_support(model_args: ModelArguments):
     from transformers import AutoConfig as HfAutoConfig
-    config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
+
+    config = HfAutoConfig.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
     if config.model_type not in MCA_SUPPORTED_MODELS:
         raise ValueError(f"Model {config.model_type} is not supported by MCA.")
 
+
 def run_pt(
     model_args: ModelArguments,
     data_args: DataArguments,
@@ -161,22 +166,23 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["multi_modal_projector"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
-                p.requires_grad_(False)
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+        params_to_freeze = []
+        if finetuning_args.freeze_vision_tower:
+            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
+
+        if finetuning_args.freeze_multi_modal_projector:
+            params_to_freeze.extend(["multi_modal_projector"])
+
+        if finetuning_args.freeze_language_model:
+            params_to_freeze.extend(["embedding", "decoder", "output_layer"])
+
+        if params_to_freeze:
+            for name, p in model.named_parameters():
+                if any(name.startswith(k) for k in params_to_freeze):
+                    p.requires_grad_(False)
 
-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
         padding="max_length" if pad_to_max else "longest",
@@ -239,9 +245,7 @@ def run_dpo(
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     data_args.cutoff_len -= 1
 
-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     dpo_config = DPOConfig(
         beta=finetuning_args.pref_beta,
         pref_loss=finetuning_args.pref_loss,
@@ -289,4 +293,3 @@ def run_dpo(
             keys += ["eval_loss"]
-
         plot_loss(training_args.output_dir, keys=keys)
 

@@ -71,13 +71,17 @@ def _training_function(config: dict[str, Any]) -> None:
             raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
         if finetuning_args.stage == "pt":
             from .mca import run_pt as run_pt_mca
+
             run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
         elif finetuning_args.stage == "sft":
             from .mca import run_sft as run_sft_mca
+
             run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
-        else:  # dpo
+        elif finetuning_args.stage == "dpo":
             from .mca import run_dpo as run_dpo_mca
+
             run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+
     elif finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":

@@ -24,7 +24,7 @@ class KernelType(str, Enum):
 
 
 class DeviceType(str, Enum):
-    CPU = 'cpu'
-    CUDA = 'cuda'
-    NPU = 'npu'
-    XPU = 'xpu'
+    CPU = "cpu"
+    CUDA = "cuda"
+    NPU = "npu"
+    XPU = "xpu"

@@ -27,14 +27,11 @@ def _npu_swiglu_forward(self, hidden_state):
     import torch_npu
 
     return self.down_proj(
-        torch_npu.npu_swiglu(
-            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
-        )
+        torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
     )
 
 
 class NpuSwiGluKernel(MetaSwiGluKernel):
-
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward
 
@@ -43,7 +40,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model
 
@@ -51,7 +48,6 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
             if re.search(swiglu_pattern, module.__class__.__name__):
-
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
                 module.forward = types.MethodType(cls.kernel, module)

@@ -21,10 +21,10 @@ from .constants import DeviceType, KernelType
 
 
 class KernelRegistry:
-    _instance: Optional['KernelRegistry'] = None
+    _instance: Optional["KernelRegistry"] = None
     _initialized: bool = False
 
-    def __new__(cls, *args: Any, **kwargs: Any) -> 'KernelRegistry':
+    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
         if cls._instance is None:
             cls._instance = super().__new__(cls)
         return cls._instance
@@ -36,10 +36,7 @@ class KernelRegistry:
         self._initialized = True
 
     def register(
-        self,
-        kernel_type: KernelType,
-        device_type: DeviceType,
-        kernel_impl: Optional[Callable[..., Any]]
+        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
     ) -> None:
         """Register a kernel implementation.
 
@@ -57,11 +54,7 @@ class KernelRegistry:
         self._registry[kernel_type][device_type] = kernel_impl
         print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
 
-    def get_kernel(
-        self,
-        kernel_type: KernelType,
-        device_type: DeviceType
-    ) -> Optional[Callable[..., Any]]:
+    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
         return self._registry.get(kernel_type, {}).get(device_type)
 
 
@@ -84,35 +77,30 @@ class MetaKernel(ABC):
 
 
 class MetaFlashAttentionKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaRMSNormKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaSwiGluKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaRoPEKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaMoEKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
@@ -130,7 +118,7 @@ def discover_kernels(model: HFModel) -> list[MetaKernel]:
         return []
 
 
-def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFModel':
+def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
     """Call the MetaKernel's `apply` to perform the replacement.
 
     Corresponding replacement logic is maintained inside each kernel; the only
@@ -145,4 +133,6 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFMo
     if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
         return kernel.apply(model, **kwargs)
 
-    raise ValueError(f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead.")
+    raise ValueError(
+        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
+    )

@@ -65,7 +65,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
             if re.search(rms_norm_pattern, module.__class__.__name__):
-
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
                 module.forward = types.MethodType(cls.kernel, module)

@@ -59,7 +59,7 @@ class NpuRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
 
         This function iterates through the model's modules to find attention layers,
@@ -96,7 +96,7 @@ class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
 
         This function iterates through the model's modules to find attention layers,

@@ -23,25 +23,25 @@ def get_available_accelerator():
     """
     accelerator = torch.accelerator.current_accelerator()
     if accelerator is None:
-        return torch.device('cpu')
+        return torch.device("cpu")
     return accelerator
 
 
 @lru_cache
 def is_torch_npu_available():
-    return get_available_accelerator().type == 'npu'
+    return get_available_accelerator().type == "npu"
 
 
 @lru_cache
 def is_torch_cuda_available():
-    return get_available_accelerator().type == 'cuda'
+    return get_available_accelerator().type == "cuda"
 
 
 @lru_cache
 def is_torch_xpu_available():
-    return get_available_accelerator().type == 'xpu'
+    return get_available_accelerator().type == "xpu"
 
 
 @lru_cache
 def is_torch_mps_available():
-    return get_available_accelerator().type == 'mps'
+    return get_available_accelerator().type == "mps"

@@ -19,11 +19,10 @@ from transformers import AutoModelForCausalLM
 
 
 class TestKernelPlugin(unittest.TestCase):
-
-    @patch('torch.accelerator.current_accelerator')
+    @patch("torch.accelerator.current_accelerator")
     def test_apply_kernel(self, mock_get_accelerator):
         mock_device = MagicMock()
-        mock_device.type = 'npu'
+        mock_device.type = "npu"
         mock_get_accelerator.return_value = mock_device
 
         model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
@@ -31,7 +30,6 @@ class TestKernelPlugin(unittest.TestCase):
         original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
         original_swiglu_forward = model.model.layers[0].mlp.forward
 
-
         from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
         from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
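
Note: the guarded-import pattern that the training-arguments hunk above converges on (check availability, raise an actionable ImportError, then import) can be sketched in isolation as follows. This is a minimal illustrative sketch only: the `_env_enabled` helper and the `importlib`-based probe stand in for the repository's own `is_env_enabled` and `is_mcore_adapter_available` utilities and are not the project's actual implementation.

import importlib.util
import os

from transformers import Seq2SeqTrainingArguments


def _env_enabled(name: str) -> bool:
    # Illustrative stand-in: treat common truthy strings as "enabled".
    return os.getenv(name, "0").lower() in ("1", "true", "y")


if _env_enabled("USE_MCA"):
    # Fail fast with a clear message before attempting the optional import.
    if importlib.util.find_spec("mcore_adapter") is None:
        raise ImportError(
            "mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
        )

    from mcore_adapter import Seq2SeqTrainingArguments as BaseTrainingArguments
else:
    BaseTrainingArguments = Seq2SeqTrainingArguments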