[misc] lint mca code (#9692)

Author:    Kingsley
Date:      2025-12-29 11:44:38 +08:00
Committer: GitHub
Parent:    e97d0474fb
Commit:    bb1ba31005

2 changed files with 24 additions and 27 deletions

@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
     "qwen2",
     "qwen2_vl",
     "qwen2_5_vl",
+    "qwen3_vl",
     "qwen3",
     "qwen3_moe",
     "qwen3_next",

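Note: the new "qwen3_vl" entry matters because the `_check_model_support` helper further down reads the Hugging Face config, presumably to gate on this set before an mcore_adapter model is built. A minimal sketch of such a gate, assuming MCA_SUPPORTED_MODELS holds Hugging Face model_type strings; the helper name check_mca_support and the error message are illustrative, not the repository's code:

    from transformers import AutoConfig

    # Excerpt of the set from the hunk above; the full set has more entries.
    MCA_SUPPORTED_MODELS = {"qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3", "qwen3_moe", "qwen3_next"}

    def check_mca_support(model_name_or_path: str) -> None:
        # Only the config is fetched; no weights are downloaded.
        config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
        if config.model_type not in MCA_SUPPORTED_MODELS:
            raise ValueError(f"model_type '{config.model_type}' is not supported by the mcore_adapter backend")
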
@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
-from __future__ import annotations
 import functools
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
+from transformers import DataCollatorForSeq2Seq
 from ...data import (
     SFTDataCollatorWith4DAttentionMask,
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
 from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
 from mcore_adapter.trainer import McaTrainer
 from mcore_adapter.trainer.dpo_config import DPOConfig
-from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
 if TYPE_CHECKING:
-    from transformers import DataCollatorForSeq2Seq, TrainerCallback
+    from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+    from transformers import TrainerCallback
     from ...hparams import DataArguments, FinetuningArguments, ModelArguments
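
The import shuffle above is the standard typing.TYPE_CHECKING idiom: names needed only for annotations are imported inside the guard, so they cost nothing at runtime and cannot create import cycles, and the annotations that mention them are written as quoted forward references. A self-contained sketch of the pattern with placeholder names (run_stage is not a function from this file):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen by type checkers (mypy, pyright) only; skipped at runtime.
        from transformers import TrainerCallback

    def run_stage(callback: "TrainerCallback") -> None:
        # The quoted name is a forward reference, so this def evaluates
        # even though TrainerCallback was never imported at runtime.
        print(type(callback).__name__)
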
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
     return wrapper
-def _check_model_support(model_args: ModelArguments):
+def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig
     config = HfAutoConfig.from_pretrained(
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):
 def run_pt(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
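
Quoting the argument types and switching from "list[TrainerCallback] | None" to Optional[list["TrainerCallback"]] is what makes dropping `from __future__ import annotations` safe: without the future import, annotations are evaluated when the def statement runs, so the old spelling would need TrainerCallback at runtime and the PEP 604 "| None" union (Python 3.10+). A small illustration under that assumption:

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        from transformers import TrainerCallback

    # New spelling: evaluates on Python 3.9+ even though TrainerCallback is
    # imported only for type checking, because the name is a quoted forward
    # reference and Optional[...] avoids the "X | None" syntax.
    def new_style(callbacks: Optional[list["TrainerCallback"]] = None) -> None:
        ...

    # Old spelling, kept as a comment: without the future import this def
    # would raise NameError (TrainerCallback is undefined at runtime) and,
    # on Python 3.9, "... | None" is itself a TypeError.
    # def old_style(callbacks: list[TrainerCallback] | None = None) -> None:
    #     ...
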
@@ -104,10 +103,7 @@ def run_pt(
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
-    from transformers import DataCollatorForSeq2Seq
-    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
+    data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX,
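
The function-local import and the redundant "data_collator: DataCollatorForSeq2Seq" annotation are gone; the collator itself is unchanged. DataCollatorForSeq2Seq pads input_ids/attention_mask with the tokenizer's pad token, pads labels with label_pad_token_id (IGNORE_INDEX is the usual -100 that the loss skips), and rounds each batch length up to a multiple of 8 for tensor-core-friendly shapes. A standalone usage sketch; the Qwen checkpoint is only an example, any tokenizer with a pad token works:

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    IGNORE_INDEX = -100  # label positions ignored by the cross-entropy loss

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
    collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX,
        return_tensors="pt",
    )

    features = [
        {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [IGNORE_INDEX, 2, 3]},
        {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
    ]
    batch = collator(features)
    print(batch["input_ids"].shape)  # padded up to a multiple of 8, e.g. torch.Size([2, 8])
    print(batch["labels"][1])        # trailing positions filled with -100
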
@@ -142,11 +138,11 @@ def run_pt(
 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     # align packing flags
     # TODO: FIX SequencePacking
@@ -166,7 +162,7 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
         params_to_freeze = []
         if finetuning_args.freeze_vision_tower:
             params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
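
The change above only widens the model-type check to include qwen3_vl; the freezing approach is unchanged: prefixes such as "vision_model.blocks" are collected into params_to_freeze and then, presumably (the applying loop is outside this hunk), gradients are disabled for every parameter whose dotted name starts with one of them. A toy sketch of that mechanism; the module layout is made up and much smaller than a real mcore_adapter model:

    from torch import nn

    class ToyVLM(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.vision_model = nn.ModuleDict(
                {"patch_embed": nn.Linear(16, 32), "blocks": nn.Linear(32, 32)}
            )
            self.language_model = nn.Linear(32, 32)

    def freeze_by_prefix(model: nn.Module, prefixes: list[str]) -> None:
        # Disable gradients for parameters whose dotted name matches a prefix.
        for name, param in model.named_parameters():
            if any(name.startswith(p) for p in prefixes):
                param.requires_grad_(False)

    model = ToyVLM()
    freeze_by_prefix(model, ["vision_model.blocks", "vision_model.patch_embed"])
    print([n for n, p in model.named_parameters() if p.requires_grad])
    # -> only language_model.weight / language_model.bias remain trainable
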
@@ -220,11 +216,11 @@ def run_sft(
 def run_dpo(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]