[misc] lint mca code (#9692)

Kingsley authored 2025-12-29 11:44:38 +08:00, committed by GitHub
parent e97d0474fb
commit bb1ba31005
2 changed files with 24 additions and 27 deletions

View File

@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
     "qwen2",
     "qwen2_vl",
     "qwen2_5_vl",
+    "qwen3_vl",
     "qwen3",
     "qwen3_moe",
     "qwen3_next",

View File

@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
-from __future__ import annotations

 import functools
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
+
+from transformers import DataCollatorForSeq2Seq

 from ...data import (
     SFTDataCollatorWith4DAttentionMask,
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
 from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
 from mcore_adapter.trainer import McaTrainer
 from mcore_adapter.trainer.dpo_config import DPOConfig
-from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments

 if TYPE_CHECKING:
-    from transformers import DataCollatorForSeq2Seq, TrainerCallback
+    from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+    from transformers import TrainerCallback

     from ...hparams import DataArguments, FinetuningArguments, ModelArguments
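For context on why the annotations in the hunks below are quoted: names imported only under if TYPE_CHECKING: do not exist at runtime, so once from __future__ import annotations is dropped they must appear as strings, and Optional[...] replaces the X | None union syntax. A minimal, self-contained sketch of the pattern (the function and its body are illustrative, not from this file):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Resolved only by type checkers; never imported at runtime.
    from transformers import TrainerCallback


def run_example(callbacks: Optional[list["TrainerCallback"]] = None) -> None:
    # The quoted name keeps this annotation evaluable even though
    # TrainerCallback is not defined at runtime.
    print(len(callbacks or []))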
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
     return wrapper


-def _check_model_support(model_args: ModelArguments):
+def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig

     config = HfAutoConfig.from_pretrained(
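The hunk cuts off inside _check_model_support. For readers skimming the diff, a plausible shape for this kind of guard is sketched below; the standalone function name, the trust_remote_code flag, and the error message are assumptions, while the HfAutoConfig.from_pretrained call and the MCA_SUPPORTED_MODELS entries (first file above) come from the diff.

from transformers import AutoConfig as HfAutoConfig

# Entries visible in the first hunk of this commit; the full set in the repo is larger.
MCA_SUPPORTED_MODELS = {"qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3", "qwen3_moe", "qwen3_next"}


def check_model_support(model_name_or_path: str) -> None:
    # Load only the HF config to read the model family, then gate on it.
    config = HfAutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    if config.model_type not in MCA_SUPPORTED_MODELS:
        raise ValueError(f"Model type {config.model_type!r} is not supported by mcore_adapter.")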
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):


 def run_pt(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -104,10 +103,7 @@ def run_pt(
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
-
-    from transformers import DataCollatorForSeq2Seq
-
-    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
+    data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX,
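As a reminder of what the collator configured above does: DataCollatorForSeq2Seq pads input_ids through the tokenizer and pads labels with label_pad_token_id (here IGNORE_INDEX, conventionally -100) so padded positions are ignored by the loss, rounding lengths up to pad_to_multiple_of. A small usage sketch; the tokenizer checkpoint is a placeholder, not taken from the diff:

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

IGNORE_INDEX = -100  # positions with this label are skipped by the cross-entropy loss

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder checkpoint
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
    label_pad_token_id=IGNORE_INDEX,
)

features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
batch = data_collator(features)
print(batch["labels"].shape)  # both rows padded to a multiple of 8; labels padded with -100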
@@ -142,11 +138,11 @@ def run_pt(


 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     # align packing flags
     # TODO: FIX SequencePacking
@@ -166,7 +162,7 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
         params_to_freeze = []
         if finetuning_args.freeze_vision_tower:
             params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
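The params_to_freeze prefixes collected here are typically applied by matching parameter names and disabling their gradients. A hedged sketch of that mechanism (the helper below is illustrative; the actual freezing code is outside this hunk):

import torch.nn as nn


def freeze_by_prefix(model: nn.Module, prefixes: list[str]) -> None:
    # Turn off gradients for every parameter whose dotted name starts with
    # one of the given prefixes, e.g. "vision_model.blocks".
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in prefixes):
            param.requires_grad_(False)

Called as freeze_by_prefix(model, ["vision_model.blocks", "vision_model.patch_embed"]), this would leave the language-model parameters trainable while keeping the vision tower fixed.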
@@ -220,11 +216,11 @@ def run_sft(


 def run_dpo(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]