Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2026-03-13 07:26:00 +08:00)
[mca] support qwen3.5 (#10265)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -71,6 +71,7 @@ def convert(
     pipeline_model_parallel_size: int = 1,
     expert_model_parallel_size: int = 1,
     virtual_pipeline_model_parallel_size: int | None = None,
+    moe_grouped_gemm: bool | None = None,
 ):
     """Convert checkpoint between MCA and HuggingFace formats.
 
@@ -84,6 +85,10 @@ def convert(
         pipeline_model_parallel_size: Pipeline model parallel size
         expert_model_parallel_size: Expert model parallel size
         virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
+        moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
+            weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
+            rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
+            Must match the format used when saving the checkpoint.
     """
     if bf16 and fp16:
         raise ValueError("bf16 and fp16 cannot be both True.")
@@ -97,8 +102,9 @@ def convert(
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+        moe_grouped_gemm=moe_grouped_gemm,
+        transformer_impl="transformer_engine",  # hard code here since we default using te for training
     )
 
     convert_checkpoint_to_mca(
         checkpoint_path,
         output_path,
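The two expert-weight layouts named in the new moe_grouped_gemm docstring differ
only in their key naming, so one can be derived from the other by a mechanical
rename. The sketch below is illustrative only (a toy helper, not part of the
converter) and assumes the per-expert keys follow the local_experts.{i} pattern
quoted in the docstring:

    def flatten_expert_weights(state_dict):
        """Rename per-expert keys (...local_experts.{i}.linear_fc1.weight)
        to the flattened grouped-gemm layout (...linear_fc1.weight{i})."""
        out = {}
        for key, value in state_dict.items():
            if ".local_experts." in key:
                prefix, rest = key.split(".local_experts.", 1)
                idx, param = rest.split(".", 1)    # "0", "linear_fc1.weight"
                base, name = param.rsplit(".", 1)  # "linear_fc1", "weight"
                out[f"{prefix}.{base}.{name}{idx}"] = value
            else:
                out[key] = value
        return out

    sd = {"mlp.experts.local_experts.0.linear_fc1.weight": "w0",
          "mlp.experts.local_experts.1.linear_fc1.weight": "w1"}
    print(flatten_expert_weights(sd))
    # {'mlp.experts.linear_fc1.weight0': 'w0', 'mlp.experts.linear_fc1.weight1': 'w1'}

Whichever layout is used when a checkpoint is saved must also be requested when
it is loaded, which is why the docstring insists the flag match the saved
checkpoint.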
@@ -69,6 +69,8 @@ MCA_SUPPORTED_MODELS = {
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
+    "qwen3_5",
+    "qwen3_5_moe",
 }
 
 METHODS = ["full", "freeze", "lora", "oft"]
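The new entries mirror the HF config.model_type strings for the Qwen3.5 dense
and MoE variants, so support can be probed before attempting a conversion. A
minimal check, assuming a local checkpoint at a hypothetical path:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("/path/to/qwen3_5_ckpt")  # hypothetical path
    print(config.model_type in MCA_SUPPORTED_MODELS)  # True when model_type is "qwen3_5"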
@@ -470,7 +470,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
         training_args.resume_from_checkpoint is None
         and training_args.do_train
         and os.path.isdir(training_args.output_dir)
-        and not training_args.overwrite_output_dir
+        and not getattr(training_args, "overwrite_output_dir", False)  # for mca training args and transformers >= 5.0
         and can_resume_from_checkpoint
     ):
         last_checkpoint = get_last_checkpoint(training_args.output_dir)
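The switch from a direct attribute read to getattr(..., False) is defensive:
per the inline comment, overwrite_output_dir may be absent on the MCA training
args and under transformers >= 5.0, where a plain attribute access would raise
AttributeError. A toy illustration with hypothetical stand-in classes:

    class LegacyArgs:  # defines the flag, like older TrainingArguments
        overwrite_output_dir = True

    class McaArgs:  # no such attribute, like the cases named in the comment
        pass

    for args in (LegacyArgs(), McaArgs()):
        # getattr falls back to False instead of raising AttributeError
        print(getattr(args, "overwrite_output_dir", False))
    # True
    # False

The same pattern is applied in the LogCallback hunk below, where the flag gates
deletion of a previous trainer log.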
@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
         if (
             args.should_save
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
-            and args.overwrite_output_dir
+            and getattr(args, "overwrite_output_dir", False)
         ):
             logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import functools
+import json
+import os
 from collections.abc import Sequence
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Optional
@@ -77,20 +79,25 @@ def _data_collator_wrapper(data_collator: Any):
 
 def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig
+    if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")):  # load from mcore ckpt
+        mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
+        model_type = mca_config.get("hf_model_type", None)
+    else:
+        config = HfAutoConfig.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+        )
+        model_type = config.model_type
 
-    config = HfAutoConfig.from_pretrained(
-        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
-    )
-    if config.model_type not in MCA_SUPPORTED_MODELS:
+    if model_type not in MCA_SUPPORTED_MODELS:
         raise ValueError(
-            f"Model {config.model_type} is not supported by mcore_adapter."
+            f"Model {model_type} is not supported by mcore_adapter."
             "You can try to upgrade mcore_adapter to the latest version for more supported models."
         )
 
 
 def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
     """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
-    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
+    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
         return
 
     params_to_freeze = []
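The new branch lets _check_model_support resolve the model type directly from
an MCA (Megatron-Core) checkpoint directory, whose mca_config.json records the
original HF model type, before falling back to AutoConfig. A self-contained
sketch of that lookup order, using a throwaway directory (the supported-models
set below is abridged):

    import json
    import os
    import tempfile

    MCA_SUPPORTED_MODELS = {"qwen3", "qwen3_5", "qwen3_5_moe"}  # abridged

    def detect_model_type(path):
        """Prefer mca_config.json when present, as the patched code does."""
        mca_path = os.path.join(path, "mca_config.json")
        if os.path.exists(mca_path):
            with open(mca_path) as f:
                return json.load(f).get("hf_model_type")
        return None  # the real code falls back to HfAutoConfig.from_pretrained

    with tempfile.TemporaryDirectory() as ckpt:
        with open(os.path.join(ckpt, "mca_config.json"), "w") as f:
            json.dump({"hf_model_type": "qwen3_5"}, f)
        print(detect_model_type(ckpt) in MCA_SUPPORTED_MODELS)  # True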