[mca] support qwen3.5 (#10265)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

Author:    Kingsley
Date:      2026-03-10 10:55:16 +08:00
Committer: GitHub
Parent:    edeb953bc7
Commit:    a3d44e3152

5 changed files with 24 additions and 9 deletions

View File

@@ -71,6 +71,7 @@ def convert(
     pipeline_model_parallel_size: int = 1,
     expert_model_parallel_size: int = 1,
     virtual_pipeline_model_parallel_size: int | None = None,
+    moe_grouped_gemm: bool | None = None,
 ):
     """Convert checkpoint between MCA and HuggingFace formats.
@@ -84,6 +85,10 @@ def convert(
         pipeline_model_parallel_size: Pipeline model parallel size
         expert_model_parallel_size: Expert model parallel size
         virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
+        moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
+            weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
+            rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
+            Must match the format used when saving the checkpoint.
     """
     if bf16 and fp16:
         raise ValueError("bf16 and fp16 cannot be both True.")
@@ -97,8 +102,9 @@ def convert(
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+        moe_grouped_gemm=moe_grouped_gemm,
+        transformer_impl="transformer_engine",  # hard code here since we default using te for training
     )
     convert_checkpoint_to_mca(
         checkpoint_path,
         output_path,
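
Side note on the new flag (not part of the diff): the docstring's two layouts imply a
mechanical renaming of expert-weight keys. A minimal sketch, assuming a hypothetical
helper name; only the key patterns themselves are taken from the docstring above.

import re

def per_expert_to_grouped(key: str) -> str:
    """Map a per-expert key to its grouped-gemm equivalent, e.g.
    '...local_experts.3.linear_fc1.weight' -> '...linear_fc1.weight3'."""
    m = re.match(r"(.*)local_experts\.(\d+)\.(linear_fc[12])\.weight$", key)
    if m is None:
        return key  # not an expert weight; leave unchanged
    prefix, expert_idx, fc = m.groups()
    return f"{prefix}{fc}.weight{expert_idx}"

assert per_expert_to_grouped(
    "mlp.experts.local_experts.0.linear_fc1.weight"
) == "mlp.experts.linear_fc1.weight0"

As the docstring warns, the moe_grouped_gemm value passed to convert() must match
whichever of these layouts the checkpoint was saved with.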

View File

@@ -69,6 +69,8 @@ MCA_SUPPORTED_MODELS = {
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
+    "qwen3_5",
+    "qwen3_5_moe",
 }
 METHODS = ["full", "freeze", "lora", "oft"]

View File

@@ -470,7 +470,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
         training_args.resume_from_checkpoint is None
         and training_args.do_train
         and os.path.isdir(training_args.output_dir)
-        and not training_args.overwrite_output_dir
+        and not getattr(training_args, "overwrite_output_dir", False)  # for mca training args and transformers >= 5.0
         and can_resume_from_checkpoint
     ):
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
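
Context for the getattr change here and in the LogCallback hunk below (illustration
only, with a toy class standing in for the real training args): per the inline comment,
some training-args objects do not define overwrite_output_dir at all, so a plain
attribute access would raise instead of falling through to False.

class MinimalArgs:
    # toy stand-in lacking overwrite_output_dir, like the mca training args
    # or transformers >= 5.0 mentioned in the comment above
    output_dir = "saves/demo"

args = MinimalArgs()

try:
    _ = args.overwrite_output_dir  # the old direct access
except AttributeError:
    pass  # the failure mode the patch avoids

assert getattr(args, "overwrite_output_dir", False) is False  # patched pattern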

View File

@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
         if (
             args.should_save
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
-            and args.overwrite_output_dir
+            and getattr(args, "overwrite_output_dir", False)
         ):
             logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))

View File

@@ -13,6 +13,8 @@
 # limitations under the License.
 import functools
+import json
+import os
 from collections.abc import Sequence
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Optional
@@ -77,20 +79,25 @@ def _data_collator_wrapper(data_collator: Any):
 def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig

-    config = HfAutoConfig.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )
-    if config.model_type not in MCA_SUPPORTED_MODELS:
+    if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")):  # load from mcore ckpt
+        mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
+        model_type = mca_config.get("hf_model_type", None)
+    else:
+        config = HfAutoConfig.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+        )
+        model_type = config.model_type
+    if model_type not in MCA_SUPPORTED_MODELS:
         raise ValueError(
-            f"Model {config.model_type} is not supported by mcore_adapter."
+            f"Model {model_type} is not supported by mcore_adapter."
             "You can try to upgrade mcore_adapter to the latest version for more supported models."
         )


 def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
     """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
-    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
+    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
         return

     params_to_freeze = []
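
To see the new detection branch end to end, a self-contained sketch: the directory and
the qwen3_5_moe value are made up for illustration, while the mca_config.json filename
and the hf_model_type field come from the diff above.

import json
import os
import tempfile

ckpt_dir = tempfile.mkdtemp()  # fabricated mcore checkpoint directory
with open(os.path.join(ckpt_dir, "mca_config.json"), "w") as f:
    json.dump({"hf_model_type": "qwen3_5_moe"}, f)

# Mirrors the branch added above: prefer mca_config.json when present,
# fall back to HF AutoConfig otherwise.
cfg_path = os.path.join(ckpt_dir, "mca_config.json")
if os.path.exists(cfg_path):
    model_type = json.load(open(cfg_path)).get("hf_model_type", None)
print(model_type)  # -> qwen3_5_moe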