diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index b5adc139..cbb13455 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -137,7 +137,6 @@ def _load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             num_proc=data_args.preprocessing_num_workers,
-            trust_remote_code=model_args.trust_remote_code,
             streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
         if data_args.streaming and dataset_attr.load_from == "file":
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 691a4cf3..0f6f0973 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -70,7 +70,6 @@ if TYPE_CHECKING:
     from transformers.image_processing_utils import BaseImageProcessor
     from transformers.video_processing_utils import BaseVideoProcessor
-
     class EncodedImage(TypedDict):
         path: Optional[str]
         bytes: Optional[bytes]
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index e84e088b..83166589 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -56,7 +56,18 @@ LAYERNORM_NAMES = {"norm", "ln"}
 
 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
 
-MCA_SUPPORTED_MODELS = {"deepseek_v3", "llama", "mistral", "mixtral", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "qwen3_next"}
+MCA_SUPPORTED_MODELS = {
+    "deepseek_v3",
+    "llama",
+    "mistral",
+    "mixtral",
+    "qwen2",
+    "qwen2_vl",
+    "qwen2_5_vl",
+    "qwen3",
+    "qwen3_moe",
+    "qwen3_next",
+}
 
 METHODS = ["full", "freeze", "lora", "oft"]
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 6a3aaaff..ef690d7b 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -475,7 +475,12 @@ class FinetuningArguments(
     )
     use_mca: bool = field(
         default=False,
-        metadata={"help": "Whether or not to use MCA (Megatron Core Adapter) training. Controlled by USE_MCA environment variable."},
+        metadata={
+            "help": (
+                "Whether or not to use MCA (Megatron Core Adapter) training. "
+                "Controlled by USE_MCA environment variable."
+            )
+        },
     )
     use_muon: bool = field(
         default=False,
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index eca60407..e830d0cc 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -55,12 +55,16 @@ _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, Finetuning
 
 if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
     from mcore_adapter import TrainingArguments as McaTrainingArguments
+
     _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
-    _TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[
+        ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments
+    ]
 else:
     _TRAIN_MCA_ARGS = []
     _TRAIN_MCA_CLS = tuple()
+
+
 def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
     r"""Get arguments from the command line or a config file."""
     if args is not None:
diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py
index 4c83a93c..46b40a2d 100644
--- a/src/llamafactory/hparams/training_args.py
+++ b/src/llamafactory/hparams/training_args.py
@@ -20,17 +20,18 @@ from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
 
 from ..extras.misc import is_env_enabled, use_ray
+from ..extras.packages import is_mcore_adapter_available
 
 
 if is_env_enabled("USE_MCA"):
-    try:
-        from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
-        BaseTrainingArguments = McaSeq2SeqTrainingArguments
-    except ImportError:
+    if not is_mcore_adapter_available():
         raise ImportError(
-            "mcore_adapter is required when USE_MCA=1.",
-            "Please install `mcore_adapter` and its dependencies."
+            "mcore_adapter is required when USE_MCA=1. Please install `mcore_adapter` and its dependencies."
         )
+
+    from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+
+    BaseTrainingArguments = McaSeq2SeqTrainingArguments
 else:
     BaseTrainingArguments = Seq2SeqTrainingArguments
diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py
index 4d0f10c2..99f2ea3e 100644
--- a/src/llamafactory/launcher.py
+++ b/src/llamafactory/launcher.py
@@ -54,8 +54,7 @@ def launch():
     )
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
 
-    if is_env_enabled("USE_MCA"):
-        # force use torchrun
+    if is_env_enabled("USE_MCA"):  # force use torchrun
        os.environ["FORCE_TORCHRUN"] = "1"
 
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
diff --git a/src/llamafactory/train/mca/__init__.py b/src/llamafactory/train/mca/__init__.py
index 2a68229e..2b3fb6eb 100644
--- a/src/llamafactory/train/mca/__init__.py
+++ b/src/llamafactory/train/mca/__init__.py
@@ -16,4 +16,3 @@ from .workflow import run_dpo, run_pt, run_sft
 
 __all__ = ["run_dpo", "run_pt", "run_sft"]
-
diff --git a/src/llamafactory/train/mca/workflow.py b/src/llamafactory/train/mca/workflow.py
index 2aa523db..4684e827 100644
--- a/src/llamafactory/train/mca/workflow.py
+++ b/src/llamafactory/train/mca/workflow.py
@@ -75,12 +75,17 @@ def _data_collator_wrapper(data_collator: Any):
     return wrapper
 
+
 def _check_model_support(model_args: ModelArguments):
     from transformers import AutoConfig as HfAutoConfig
-    config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
+
+    config = HfAutoConfig.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
     if config.model_type not in MCA_SUPPORTED_MODELS:
         raise ValueError(f"Model {config.model_type} is not supported by MCA.")
 
+
 def run_pt(
     model_args: ModelArguments,
     data_args: DataArguments,
@@ -161,22 +166,23 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["multi_modal_projector"]):
-                p.requires_grad_(False)
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
-        for name, p in model.named_parameters():
-            if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
-                p.requires_grad_(False)
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+        params_to_freeze = []
+        if finetuning_args.freeze_vision_tower:
+            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
 
-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+        if finetuning_args.freeze_multi_modal_projector:
+            params_to_freeze.extend(["multi_modal_projector"])
+
+        if finetuning_args.freeze_language_model:
+            params_to_freeze.extend(["embedding", "decoder", "output_layer"])
+
+        if params_to_freeze:
+            for name, p in model.named_parameters():
+                if any(name.startswith(k) for k in params_to_freeze):
+                    p.requires_grad_(False)
+
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     data_collator = SFTDataCollatorWith4DAttentionMask(
         template=template,
         padding="max_length" if pad_to_max else "longest",
@@ -239,9 +245,7 @@ def run_dpo(
     dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
     data_args.cutoff_len -= 1
 
-    pad_to_max = (
-        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
-    )
+    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     dpo_config = DPOConfig(
         beta=finetuning_args.pref_beta,
         pref_loss=finetuning_args.pref_loss,
@@ -289,4 +293,3 @@ def run_dpo(
             keys += ["eval_loss"]
 
         plot_loss(training_args.output_dir, keys=keys)
-
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index a3538ad3..f8b84107 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -71,13 +71,17 @@ def _training_function(config: dict[str, Any]) -> None:
             raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
         if finetuning_args.stage == "pt":
             from .mca import run_pt as run_pt_mca
+
             run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
         elif finetuning_args.stage == "sft":
             from .mca import run_sft as run_sft_mca
+
             run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
-        else:  # dpo
+        elif finetuning_args.stage == "dpo":
             from .mca import run_dpo as run_dpo_mca
+
             run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+
     elif finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
index 063ebb44..1ad988bf 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/constants.py
@@ -24,7 +24,7 @@ class KernelType(str, Enum):
 
 
 class DeviceType(str, Enum):
-    CPU = 'cpu'
-    CUDA = 'cuda'
-    NPU = 'npu'
-    XPU = 'xpu'
+    CPU = "cpu"
+    CUDA = "cuda"
+    NPU = "npu"
+    XPU = "xpu"
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
index be331dec..702d27bc 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/mlp/npu_swiglu.py
@@ -27,14 +27,11 @@ def _npu_swiglu_forward(self, hidden_state):
     import torch_npu
 
     return self.down_proj(
-        torch_npu.npu_swiglu(
-            torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1
-        )
+        torch_npu.npu_swiglu(torch.cat((self.gate_proj(hidden_state), self.up_proj(hidden_state)), dim=-1), dim=-1)
     )
 
 
 class NpuSwiGluKernel(MetaSwiGluKernel):
-
     device = DeviceType.NPU
     kernel = _npu_swiglu_forward
 
@@ -43,7 +40,7 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         if not is_torch_npu_available():
             return model
 
@@ -51,7 +48,6 @@ class NpuSwiGluKernel(MetaSwiGluKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
"RMSNorm" if re.search(swiglu_pattern, module.__class__.__name__): - # Bind function as an instance method to preserve `self` semantics # and replace the original forward module.forward = types.MethodType(cls.kernel, module) diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py index 33597c48..45e14a09 100644 --- a/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py +++ b/src/llamafactory/v1/plugins/model_plugins/kernels/registry.py @@ -21,10 +21,10 @@ from .constants import DeviceType, KernelType class KernelRegistry: - _instance: Optional['KernelRegistry'] = None + _instance: Optional["KernelRegistry"] = None _initialized: bool = False - def __new__(cls, *args: Any, **kwargs: Any) -> 'KernelRegistry': + def __new__(cls, *args: Any, **kwargs: Any) -> "KernelRegistry": if cls._instance is None: cls._instance = super().__new__(cls) return cls._instance @@ -36,10 +36,7 @@ class KernelRegistry: self._initialized = True def register( - self, - kernel_type: KernelType, - device_type: DeviceType, - kernel_impl: Optional[Callable[..., Any]] + self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]] ) -> None: """Register a kernel implementation. @@ -57,11 +54,7 @@ class KernelRegistry: self._registry[kernel_type][device_type] = kernel_impl print(f"Registered kernel {kernel_type.name} for device {device_type.name}.") - def get_kernel( - self, - kernel_type: KernelType, - device_type: DeviceType - ) -> Optional[Callable[..., Any]]: + def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]: return self._registry.get(kernel_type, {}).get(device_type) @@ -84,35 +77,30 @@ class MetaKernel(ABC): class MetaFlashAttentionKernel(MetaKernel): - @classmethod def apply(cls, model: HFModel, **kwargs) -> HFModel: raise NotImplementedError class MetaRMSNormKernel(MetaKernel): - @classmethod def apply(cls, model: HFModel, **kwargs) -> HFModel: raise NotImplementedError class MetaSwiGluKernel(MetaKernel): - @classmethod def apply(cls, model: HFModel, **kwargs) -> HFModel: raise NotImplementedError class MetaRoPEKernel(MetaKernel): - @classmethod def apply(cls, model: HFModel, **kwargs) -> HFModel: raise NotImplementedError class MetaMoEKernel(MetaKernel): - @classmethod def apply(cls, model: HFModel, **kwargs) -> HFModel: raise NotImplementedError @@ -130,7 +118,7 @@ def discover_kernels(model: HFModel) -> list[MetaKernel]: return [] -def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFModel': +def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> "HFModel": """Call the MetaKernel's `apply` to perform the replacement. Corresponding replacement logic is maintained inside each kernel; the only @@ -145,4 +133,6 @@ def apply_kernel(model: HFModel, kernel: type[MetaKernel], /, **kwargs) -> 'HFMo if issubclass(kernel, MetaKernel) and kernel.device == get_available_accelerator().type: return kernel.apply(model, **kwargs) - raise ValueError(f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead.") + raise ValueError( + f"{kernel} must be a MetaKernel instance, or the kernel don't match the device type. got {kernel.device} and {get_available_accelerator().type} instead." 
+    )
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
index 018758ee..d6f032b9 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rms_norm/npu_rms_norm.py
@@ -65,7 +65,6 @@ class NpuRMSNormKernel(MetaRMSNormKernel):
         for name, module in model.named_modules():
             # Match any module whose class name contains "RMSNorm"
             if re.search(rms_norm_pattern, module.__class__.__name__):
-
                 # Bind function as an instance method to preserve `self` semantics
                 # and replace the original forward
                 module.forward = types.MethodType(cls.kernel, module)
diff --git a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
index a1d41dd4..8cb40575 100644
--- a/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
+++ b/src/llamafactory/v1/plugins/model_plugins/kernels/rope/npu_rope.py
@@ -59,7 +59,7 @@ class NpuRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
 
         This function iterates through the model's modules to find attention layers,
@@ -96,7 +96,7 @@ class NpuQwen2VLRoPEKernel(MetaRoPEKernel):
         KERNEL_REGISTRY.register(kernel_type, device_type, cls)
 
     @classmethod
-    def apply(cls, model, **kwargs) -> 'HFModel':
+    def apply(cls, model, **kwargs) -> "HFModel":
         """Apply RoPE acceleration by monkey-patching `apply_rotary_pos_emb`.
 
         This function iterates through the model's modules to find attention layers,
diff --git a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py
index 4f090e7d..e7a9bf30 100644
--- a/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py
+++ b/src/llamafactory/v1/plugins/trainer_plugins/distributed/accelerate.py
@@ -23,25 +23,25 @@ def get_available_accelerator():
     """
     accelerator = torch.accelerator.current_accelerator()
     if accelerator is None:
-        return torch.device('cpu')
+        return torch.device("cpu")
     return accelerator
 
 
 @lru_cache
 def is_torch_npu_available():
-    return get_available_accelerator().type == 'npu'
+    return get_available_accelerator().type == "npu"
 
 
 @lru_cache
 def is_torch_cuda_available():
-    return get_available_accelerator().type == 'cuda'
+    return get_available_accelerator().type == "cuda"
 
 
 @lru_cache
 def is_torch_xpu_available():
-    return get_available_accelerator().type == 'xpu'
+    return get_available_accelerator().type == "xpu"
 
 
 @lru_cache
 def is_torch_mps_available():
-    return get_available_accelerator().type == 'mps'
+    return get_available_accelerator().type == "mps"
diff --git a/tests_v1/plugins/model_plugins/test_kernel_plugin.py b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
index a89b8bd7..2830d8c5 100644
--- a/tests_v1/plugins/model_plugins/test_kernel_plugin.py
+++ b/tests_v1/plugins/model_plugins/test_kernel_plugin.py
@@ -19,11 +19,10 @@ from transformers import AutoModelForCausalLM
 
 
 class TestKernelPlugin(unittest.TestCase):
-
-    @patch('torch.accelerator.current_accelerator')
+    @patch("torch.accelerator.current_accelerator")
     def test_apply_kernel(self, mock_get_accelerator):
         mock_device = MagicMock()
-        mock_device.type = 'npu'
+        mock_device.type = "npu"
         mock_get_accelerator.return_value = mock_device
 
         model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")
@@ -31,7 +30,6 @@ class TestKernelPlugin(unittest.TestCase):
         original_rmsnorm_forward = model.model.layers[0].input_layernorm.forward
         original_swiglu_forward = model.model.layers[0].mlp.forward
 
-
         from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
         from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
         from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm
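
A minimal usage sketch (not part of the patch) of the kernel-registry API touched above, mirroring the updated test_kernel_plugin.py. It assumes an Ascend NPU runtime where torch.accelerator.current_accelerator() reports type "npu"; on any other device, apply_kernel raises the ValueError defined in registry.py.

from transformers import AutoModelForCausalLM

from llamafactory.v1.plugins.model_plugins.kernels.mlp import npu_swiglu
from llamafactory.v1.plugins.model_plugins.kernels.registry import apply_kernel
from llamafactory.v1.plugins.model_plugins.kernels.rms_norm import npu_rms_norm

model = AutoModelForCausalLM.from_pretrained("llamafactory/tiny-random-qwen2.5")

# apply_kernel dispatches to the kernel's apply() classmethod, which rebinds the
# fused NPU forward onto each matching module in place and returns the model.
model = apply_kernel(model, npu_rms_norm.NpuRMSNormKernel)
model = apply_kernel(model, npu_swiglu.NpuSwiGluKernel)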