diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index e78c8b908..00193c276 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
     "qwen2",
     "qwen2_vl",
     "qwen2_5_vl",
+    "qwen3_vl",
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
diff --git a/src/llamafactory/train/mca/workflow.py b/src/llamafactory/train/mca/workflow.py
index 4684e827d..affa2efc8 100644
--- a/src/llamafactory/train/mca/workflow.py
+++ b/src/llamafactory/train/mca/workflow.py
@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
-
-from __future__ import annotations
 
 import functools
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
+
+from transformers import DataCollatorForSeq2Seq
 
 from ...data import (
     SFTDataCollatorWith4DAttentionMask,
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
 from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
 from mcore_adapter.trainer import McaTrainer
 from mcore_adapter.trainer.dpo_config import DPOConfig
-from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
 
 if TYPE_CHECKING:
-    from transformers import DataCollatorForSeq2Seq, TrainerCallback
+    from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+    from transformers import TrainerCallback
 
     from ...hparams import DataArguments, FinetuningArguments, ModelArguments
 
 
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
     return wrapper
 
 
-def _check_model_support(model_args: ModelArguments):
+def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig
 
     config = HfAutoConfig.from_pretrained(
@@ -87,11 +86,11 @@
 
 
 def run_pt(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -104,10 +103,7 @@ def run_pt(
 
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
-
-    from transformers import DataCollatorForSeq2Seq
-
-    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
+    data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX,
@@ -142,11 +138,11 @@
 
 
 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     # align packing flags
     # TODO: FIX SequencePacking
@@ -166,7 +162,7 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
         params_to_freeze = []
         if finetuning_args.freeze_vision_tower:
             params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
@@ -220,11 +216,11 @@
 
 
 def run_dpo(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
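Note: the run_sft hunk above only extends the branch that collects parameter-name prefixes (`vision_model.blocks`, `vision_model.patch_embed`) to `qwen3_vl`; the loop that actually disables gradients sits outside the diff context. For reference, a minimal sketch of what prefix-based freezing amounts to, assuming a plain name-prefix match (the helper name `freeze_by_prefix` is illustrative and not part of workflow.py):

```python
from torch import nn


def freeze_by_prefix(model: nn.Module, prefixes: list[str]) -> None:
    """Disable gradients for parameters whose qualified name starts with any prefix."""
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in prefixes):
            param.requires_grad_(False)


# e.g. with freeze_vision_tower enabled for a qwen3_vl checkpoint:
# freeze_by_prefix(model, ["vision_model.blocks", "vision_model.patch_embed"])
```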