Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

[misc] lint code (#9395)

parent 215580c77d
commit 3ae15da9c0
@@ -137,7 +137,6 @@ def _load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             num_proc=data_args.preprocessing_num_workers,
-            trust_remote_code=model_args.trust_remote_code,
             streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
         if data_args.streaming and dataset_attr.load_from == "file":
@@ -70,7 +70,6 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
     from transformers.video_processing_utils import BaseVideoProcessor
 
-
     class EncodedImage(TypedDict):
         path: Optional[str]
         bytes: Optional[bytes]
@@ -56,7 +56,18 @@ LAYERNORM_NAMES = {"norm", "ln"}
 
 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
 
-MCA_SUPPORTED_MODELS = {"deepseek_v3", "llama", "mistral", "mixtral", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "qwen3_next"}
+MCA_SUPPORTED_MODELS = {
+    "deepseek_v3",
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_vl",
+    "qwen2_5_vl",
+    "qwen3",
+    "qwen3_moe",
+    "qwen3_next",
+}
 
 METHODS = ["full", "freeze", "lora", "oft"]
 
@@ -475,7 +475,12 @@ class FinetuningArguments(
     )
     use_mca: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use MCA (Megatron Core Adapter) training. Controlled by USE_MCA environment variable."},
+        metadata={
+            "help": (
+                "Whether or not to use MCA (Megatron Core Adapter) training. "
+                "Controlled by USE_MCA environment variable."
+            )
+        },
     )
     use_muon: bool = field(
         default=False,
@@ -55,12 +55,16 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
 
 if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
     from mcore_adapter import TrainingArguments as McaTrainingArguments
+
     _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
-    _TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[
+        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
+    ]
 else:
     _TRAIN_MCA_ARGS = []
     _TRAIN_MCA_CLS = tuple()
 
+
 def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
     r"""Get arguments from the command line or a config file."""
     if args is not None:
@@ -20,17 +20,18 @@ from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
 
 from ..extras.misc import is_env_enabled, use_ray
+from ..extras.packages import is_mcore_adapter_available
 
 
 if is_env_enabled("USE_MCA"):
-    try:
-        from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
-        BaseTrainingArguments = McaSeq2SeqTrainingArguments
-    except ImportError:
+    if not is_mcore_adapter_available():
         raise ImportError(
-            "mcore_adapter is required when USE_MCA=1.",
-            "Please install `mcore_adapter` and its dependencies."
+            "mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
         )
+
+    from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+
+    BaseTrainingArguments = McaSeq2SeqTrainingArguments
 else:
     BaseTrainingArguments = Seq2SeqTrainingArguments
 
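For orientation, the USE_MCA switch that several of these hunks check is a plain environment variable; a minimal sketch of opting in from a driver script (an assumption for illustration, not part of this commit; the actual reading happens inside is_env_enabled, which is not shown in the diff):

    import os

    # Assumed usage: enable the Megatron Core Adapter (MCA) code path before
    # LLaMA-Factory parses its arguments. The hunks above gate the MCA imports
    # and the forced torchrun launch on is_env_enabled("USE_MCA").
    os.environ["USE_MCA"] = "1"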
@@ -54,8 +54,7 @@ def launch():
     )
 
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-    if is_env_enabled("USE_MCA"):
-    # force use torchrun
+    if is_env_enabled("USE_MCA"):  # force use torchrun
         os.environ["FORCE_TORCHRUN"] = "1"
 
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
@@ -16,4 +16,3 @@ from .workflow import run_dpo, run_pt, run_sft
 
 
 __all__ = ["run_dpo", "run_pt", "run_sft"]
-
@@ -75,12 +75,17 @@ def _data_collator_wrapper(data_collator: Any):
 
     return wrapper
 
+
 def _check_model_support(model_args: ModelArguments):
     from transformers import AutoConfig as HfAutoConfig
-    config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
+
+    config = HfAutoConfig.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
     if config.model_type not in MCA_SUPPORTED_MODELS:
         raise ValueError(f"Model {config.model_type} is not supported by MCA.")
 
+
 def run_pt(
     model_args: ModelArguments,
     data_args: DataArguments,
@@ -161,22 +166,23 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["multi_modal_projector"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
-                p.requires_grad_(False)
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+        params_to_freeze = []
+        if finetuning_args.freeze_vision_tower:
+            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
+
+        if finetuning_args.freeze_multi_modal_projector:
+            params_to_freeze.extend(["multi_modal_projector"])
+
+        if finetuning_args.freeze_language_model:
+            params_to_freeze.extend(["embedding", "decoder", "output_layer"])
+
+        if params_to_freeze:
+            for name, p in model.named_parameters():
+                if any(name.startswith(k) for k in params_to_freeze):
+                    p.requires_grad_(False)
 
-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
         padding="max_length" if pad_to_max else "longest",
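The refactor above first collects parameter-name prefixes per freeze flag and then walks named_parameters() once. A standalone sketch of the same pattern (the helper name and the toy module are illustrative, not taken from the repository):

    import torch.nn as nn


    def freeze_by_prefix(model: nn.Module, prefixes: list[str]) -> None:
        """Disable gradients for every parameter whose name starts with any given prefix."""
        if not prefixes:  # no freeze flag was set, leave the model untouched
            return

        for name, param in model.named_parameters():
            if any(name.startswith(prefix) for prefix in prefixes):
                param.requires_grad_(False)


    # toy usage: freeze only the "encoder" submodule
    toy = nn.ModuleDict({"encoder": nn.Linear(4, 4), "head": nn.Linear(4, 2)})
    freeze_by_prefix(toy, ["encoder"])
    assert all(not p.requires_grad for p in toy["encoder"].parameters())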
@@ -239,9 +245,7 @@ def run_dpo(
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     data_args.cutoff_len -= 1
 
-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     dpo_config = DPOConfig(
         beta=finetuning_args.pref_beta,
         pref_loss=finetuning_args.pref_loss,
@@ -289,4 +293,3 @@ def run_dpo(
             keys += ["eval_loss"]
 
         plot_loss(training_args.output_dir, keys=keys)
-
@@ -71,13 +71,17 @@ def _training_function(config: dict[str, Any]) -> None:
             raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
         if finetuning_args.stage == "pt":
             from .mca import run_pt as run_pt_mca
+
             run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
         elif finetuning_args.stage == "sft":
             from .mca import run_sft as run_sft_mca
+
             run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
-        else:  # dpo
+        elif finetuning_args.stage == "dpo":
             from .mca import run_dpo as run_dpo_mca
+
             run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+
     elif finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
@@ -24,7 +24,7 @@ class KernelType(str, Enum):
 
 
 class DeviceType(str, Enum):
-    CPU = 'cpu'
-    CUDA = 'cuda'
-    NPU = 'npu'
-    XPU = 'xpu'
+    CPU = "cpu"
+    CUDA = "cuda"
+    NPU = "npu"
+    XPU = "xpu"
@@ -27,14 +27,11 @@ def _npu_swiglu_forward(self, hidden_state):
     import torch_npu
 
     return self.down_proj(
-        torch_npu.npu_swiglu(
-            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
-        )
+        torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
     )
 
 
 class NpuSwiGluKernel(MetaSwiGluKernel):
-
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward
 
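For readers unfamiliar with the fused call: torch_npu.npu_swiglu over the concatenated gate/up projections computes the usual SwiGLU activation. A rough eager-mode equivalent, assuming the standard SiLU-gated MLP formulation (reference sketch only, not code from the repository):

    import torch.nn.functional as F


    def eager_swiglu_mlp(hidden_state, gate_proj, up_proj, down_proj):
        # down(silu(gate(x)) * up(x)); the NPU kernel fuses the silu-and-multiply step.
        return down_proj(F.silu(gate_proj(hidden_state)) * up_proj(hidden_state))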
@@ -43,7 +40,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model
 
@@ -51,7 +48,6 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
             if re.search(swiglu_pattern, module.__class__.__name__):
-
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
                 module.forward = types.MethodType(cls.kernel, module)
@@ -21,10 +21,10 @@ from .constants import DeviceType, KernelType
 
 
 class KernelRegistry:
-    _instance: Optional['KernelRegistry'] = None
+    _instance: Optional["KernelRegistry"] = None
     _initialized: bool = False
 
-    def __new__(cls, *args: Any, **kwargs: Any) -> 'KernelRegistry':
+    def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry":
         if cls._instance is None:
             cls._instance = super().__new__(cls)
         return cls._instance
@@ -36,10 +36,7 @@ class KernelRegistry:
         self._initialized = True
 
     def register(
-        self,
-        kernel_type: KernelType,
-        device_type: DeviceType,
-        kernel_impl: Optional[Callable[..., Any]]
+        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
     ) -> None:
         """Register a kernel implementation.
 
@@ -57,11 +54,7 @@ class KernelRegistry:
         self._registry[kernel_type][device_type] = kernel_impl
         print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
 
-    def get_kernel(
-        self,
-        kernel_type: KernelType,
-        device_type: DeviceType
-    ) -> Optional[Callable[..., Any]]:
+    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
         return self._registry.get(kernel_type, {}).get(device_type)
 
 
@@ -84,35 +77,30 @@ class MetaKernel(ABC):
 
 
 class MetaFlashAttentionKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaRMSNormKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaSwiGluKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaRoPEKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
 
 
 class MetaMoEKernel(MetaKernel):
-
     @classmethod
     def apply(cls, model: HFModel, **kwargs) -> HFModel:
         raise NotImplementedError
@@ -130,7 +118,7 @@ def discover_kernels(model: HFModel) -> list[MetaKernel]:
     return []
 
 
-def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFModel':
+def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel":
    """Call the MetaKernel's `apply` to perform the replacement.
 
     Corresponding replacement logic is maintained inside each kernel; the only
@@ -145,4 +133,6 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFModel':
     if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type:
         return kernel.apply(model, **kwargs)
 
-    raise ValueError(f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead.")
+    raise ValueError(
+        f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead."
+    )
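As a usage reference, apply_kernel is exercised by the test near the end of this diff: the current accelerator is mocked to report an NPU and the NPU kernels are applied to a tiny model. A condensed sketch of that flow (module paths follow the test imports shown below; treat the details as illustrative rather than canonical):

    from unittest.mock import MagicMock, patch

    from transformers import AutoModelForCausalLM

    # Pretend the current accelerator is an NPU so the NPU kernels match the device check.
    with patch("torch.accelerator.current_accelerator") as mock_get_accelerator:
        mock_device = MagicMock()
        mock_device.type = "npu"
        mock_get_accelerator.return_value = mock_device

        from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
        from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm

        model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
        # apply_kernel raises ValueError when kernel.device does not match the accelerator type.
        model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)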
@@ -65,7 +65,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
             if re.search(rms_norm_pattern, module.__class__.__name__):
-
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
                 module.forward = types.MethodType(cls.kernel, module)
@@ -59,7 +59,7 @@ class NpuRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
 
         This function iterates through the model's modules to find attention layers,
@@ -96,7 +96,7 @@ class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
 
         This function iterates through the model's modules to find attention layers,
@@ -23,25 +23,25 @@ def get_available_accelerator():
     """
     accelerator = torch.accelerator.current_accelerator()
     if accelerator is None:
-        return torch.device('cpu')
+        return torch.device("cpu")
     return accelerator
 
 
 @lru_cache
 def is_torch_npu_available():
-    return get_available_accelerator().type == 'npu'
+    return get_available_accelerator().type == "npu"
 
 
 @lru_cache
 def is_torch_cuda_available():
-    return get_available_accelerator().type == 'cuda'
+    return get_available_accelerator().type == "cuda"
 
 
 @lru_cache
 def is_torch_xpu_available():
-    return get_available_accelerator().type == 'xpu'
+    return get_available_accelerator().type == "xpu"
 
 
 @lru_cache
 def is_torch_mps_available():
-    return get_available_accelerator().type == 'mps'
+    return get_available_accelerator().type == "mps"
@@ -19,11 +19,10 @@ from transformers import AutoModelForCausalLM
 
 
 class TestKernelPlugin(unittest.TestCase):
-
-    @patch('torch.accelerator.current_accelerator')
+    @patch("torch.accelerator.current_accelerator")
     def test_apply_kernel(self, mock_get_accelerator):
         mock_device = MagicMock()
-        mock_device.type = 'npu'
+        mock_device.type = "npu"
         mock_get_accelerator.return_value = mock_device
 
         model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
@@ -31,7 +30,6 @@ class TestKernelPlugin(unittest.TestCase):
         original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
         original_swiglu_forward = model.model.layers[0].mlp.forward
 
-
         from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
         from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
         from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm