[mca] support qwen3.5 (#10265)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Kingsley
2026-03-10 10:55:16 +08:00
committed by GitHub
parent edeb953bc7
commit a3d44e3152
5 changed files with 24 additions and 9 deletions

View File

@@ -71,6 +71,7 @@ def convert(
pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: int | None = None,
moe_grouped_gemm: bool | None = None,
):
"""Convert checkpoint between MCA and HuggingFace formats.
@@ -84,6 +85,10 @@ def convert(
pipeline_model_parallel_size: Pipeline model parallel size
expert_model_parallel_size: Expert model parallel size
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
Must match the format used when saving the checkpoint.
"""
if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.")
@@ -97,8 +102,9 @@ def convert(
pipeline_model_parallel_size=pipeline_model_parallel_size,
expert_model_parallel_size=expert_model_parallel_size,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
moe_grouped_gemm=moe_grouped_gemm,
transformer_impl="transformer_engine", # hard code here since we default using te for training
)
convert_checkpoint_to_mca(
checkpoint_path,
output_path,

View File

@@ -69,6 +69,8 @@ MCA_SUPPORTED_MODELS = {
"qwen3",
"qwen3_moe",
"qwen3_next",
"qwen3_5",
"qwen3_5_moe",
}
METHODS = ["full", "freeze", "lora", "oft"]

View File

@@ -470,7 +470,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
training_args.resume_from_checkpoint is None
and training_args.do_train
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
and can_resume_from_checkpoint
):
last_checkpoint = get_last_checkpoint(training_args.output_dir)

View File

@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
if (
args.should_save
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
and args.overwrite_output_dir
and getattr(args, "overwrite_output_dir", False)
):
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG))

View File

@@ -13,6 +13,8 @@
# limitations under the License.
import functools
import json
import os
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional
@@ -77,20 +79,25 @@ def _data_collator_wrapper(data_collator: Any):
def _check_model_support(model_args: "ModelArguments"):
from transformers import AutoConfig as HfAutoConfig
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
model_type = mca_config.get("hf_model_type", None)
else:
config = HfAutoConfig.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
if config.model_type not in MCA_SUPPORTED_MODELS:
model_type = config.model_type
if model_type not in MCA_SUPPORTED_MODELS:
raise ValueError(
f"Model {config.model_type} is not supported by mcore_adapter."
f"Model {model_type} is not supported by mcore_adapter."
"You can try to upgrade mcore_adapter to the latest version for more supported models."
)
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
return
params_to_freeze = []