diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index eda7e5bc6..2c53d52c6 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -67,8 +67,6 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python }}
-          cache: "pip"
-          cache-dependency-path: "**/requirements*.txt"
 
       - name: Install dependencies
         run: |
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index dcf989632..ca45c7d49 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -114,6 +114,7 @@ class AttentionFunction(str, Enum):
     DISABLED = "disabled"
     SDPA = "sdpa"
     FA2 = "fa2"
+    FA3 = "fa3"
 
 
 class EngineName(str, Enum):
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index e2349f40e..086ab5cf8 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -29,16 +29,19 @@ logger = logging.get_logger(__name__)
 
 
 def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
+    from transformers.utils import is_flash_attn_2_available
+
     if getattr(config, "model_type", None) == "gpt_oss":
         from transformers.integrations.hub_kernels import load_and_register_kernel
+
         flash_attn3_kernel = "kernels-community/vllm-flash-attn3"
         load_and_register_kernel(flash_attn3_kernel)
         setattr(config, "_attn_implementation", flash_attn3_kernel)
         setattr(config, "_attn_implementation_internal", flash_attn3_kernel)
-        model_args.flash_attn = flash_attn3_kernel
+        model_args.flash_attn = AttentionFunction.FA3
+
+        logger.info_rank0("Using FlashAttention-3 with attention sink for the gpt-oss model.")
         return
-
-    from transformers.utils import is_flash_attn_2_available
 
     if getattr(config, "model_type", None) == "gemma2":
         if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py
index f7dcc85ec..a8f0e842e 100644
--- a/src/llamafactory/model/model_utils/liger_kernel.py
+++ b/src/llamafactory/model/model_utils/liger_kernel.py
@@ -78,8 +78,11 @@ def apply_liger_kernel(
     elif model_type == "qwen3_moe":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen3_moe as apply_liger_kernel
     elif model_type == "gpt_oss":
-        # Install manually from https://github.com/Comet0322/Liger-Kernel
-        from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
+        try:
+            from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss as apply_liger_kernel
+        except ImportError:
+            logger.warning_rank0("Please install liger-kernel from https://github.com/Comet0322/Liger-Kernel.")
+            return
     else:
         logger.warning_rank0("Current model does not support liger kernel.")
         return
diff --git a/src/llamafactory/model/model_utils/moe.py b/src/llamafactory/model/model_utils/moe.py
index 9e3662a27..250c38c70 100644
--- a/src/llamafactory/model/model_utils/moe.py
+++ b/src/llamafactory/model/model_utils/moe.py
@@ -82,6 +82,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
 
         _set_z3_leaf_modules(model, [Glm4vMoeTextMoE])
 
+    if model_type == "gpt_oss":
+        from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
+
+        _set_z3_leaf_modules(model, [GptOssMLP])
+
     if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock
 
@@ -129,13 +134,9 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if model_type in ("qwen3_omni_moe", "qwen3_omni_moe_thinker"):
         from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeThinkerTextSparseMoeBlock
-
+
         _set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
-
-    if model_type == "gpt_oss":
-        from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP
-
-        _set_z3_leaf_modules(model, [GptOssMLP])
+
 
 
 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.moe_aux_loss_coef: