# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional

from transformers import DataCollatorForSeq2Seq

from ...data import (
    SFTDataCollatorWith4DAttentionMask,
    get_dataset,
    get_template_and_fix_tokenizer,
)
from ...data.collator import (
    PairwiseDataCollatorWithPadding,
)
from ...extras.constants import IGNORE_INDEX, MCA_SUPPORTED_MODELS
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_mcore_adapter_available
from ...extras.ploting import plot_loss
from ...model import load_tokenizer
from ..callbacks import SaveProcessorCallback


# Fail fast with an actionable message before touching any mcore_adapter import.
if not is_mcore_adapter_available():
    raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")

from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig


if TYPE_CHECKING:
    from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
    from transformers import TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, ModelArguments


logger = get_logger(__name__)


def _data_collator_wrapper(data_collator: Any):
    """Wrap ``data_collator`` so each feature is shifted by one token before collation.

    For every feature dict in the batch, in place:

    * every key ending in ``"labels"`` drops its first element;
    * every key ending in ``"input_ids"`` drops its last element;
    * ``"attention_mask"`` and ``"position_ids"``, when present, drop their last element;
    * if the first feature has no ``*labels`` key at all (the pre-training case),
      ``"labels"`` is created as a copy of ``"input_ids"`` without its first element.

    The set of label / input-id keys is decided from the first feature only and
    applied to the whole batch.

    Args:
        data_collator: Callable that receives the (mutated) feature list.

    Returns:
        A callable taking a non-empty sequence of feature dicts and returning
        whatever ``data_collator`` returns for the shifted features.
    """

    @functools.wraps(data_collator)
    def wrapper(features: Sequence[dict[str, Any]]):
        first = features[0]
        labels_keys = [k for k in first if k.endswith("labels")]
        input_ids_keys = [k for k in first if k.endswith("input_ids")]
        for feature in features:
            if not labels_keys:  # pt
                # Must run before input_ids is trimmed below.
                feature["labels"] = deepcopy(feature["input_ids"])[1:]

            for k in labels_keys:
                feature[k] = feature[k][1:]

            for k in input_ids_keys:
                feature[k] = feature[k][:-1]

            for k in ("attention_mask", "position_ids"):
                if k in feature:
                    feature[k] = feature[k][:-1]

        return data_collator(features)

    return wrapper


def _check_model_support(model_args: "ModelArguments") -> None:
    """Raise if the model's ``model_type`` is not listed in ``MCA_SUPPORTED_MODELS``.

    The Hugging Face config is loaded from ``model_args.model_name_or_path``
    (honouring ``model_args.trust_remote_code``) to read ``model_type``.

    Raises:
        ValueError: If ``config.model_type`` is not in ``MCA_SUPPORTED_MODELS``.
    """
    from transformers import AutoConfig as HfAutoConfig

    config = HfAutoConfig.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    if config.model_type not in MCA_SUPPORTED_MODELS:
        # The trailing space keeps the two sentences apart once the adjacent literals are joined.
        raise ValueError(
            f"Model {config.model_type} is not supported by mcore_adapter. "
            "You can try to upgrade mcore_adapter to the latest version for more supported models."
        )


def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments") -> None:
    """Freeze model parameters for qwen_vl series models based on finetuning arguments.

    Does nothing unless ``model.config.hf_model_type`` is one of ``qwen2_vl``,
    ``qwen2_5_vl``, ``qwen3_vl`` or ``qwen3_vl_moe``. Otherwise every parameter whose
    name starts with one of the selected prefixes gets ``requires_grad_(False)``.
    """
    hf_model_type = getattr(model.config, "hf_model_type", None)
    if hf_model_type not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
        return

    params_to_freeze = []
    if finetuning_args.freeze_vision_tower:
        params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
        if hf_model_type in ["qwen3_vl", "qwen3_vl_moe"]:
            params_to_freeze.extend(["vision_model.pos_embed"])

    if finetuning_args.freeze_multi_modal_projector:
        params_to_freeze.extend(["multi_modal_projector"])

    if finetuning_args.freeze_language_model:
        params_to_freeze.extend(["embedding", "decoder", "output_layer"])

    if params_to_freeze:
        prefixes = tuple(params_to_freeze)  # str.startswith accepts a tuple of candidates
        for name, p in model.named_parameters():
            if name.startswith(prefixes):
                p.requires_grad_(False)


def _get_dataset_with_shifted_cutoff(
    template: Any,
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "McaSeq2SeqTrainingArguments",
    stage: str,
    tokenizer_module: Any,
):
    """Call ``get_dataset`` with ``data_args.cutoff_len`` temporarily raised by one.

    The dataset needs +1 then cut back due to the MCA shift logic. The original
    ``cutoff_len`` is restored even when ``get_dataset`` raises, so ``data_args``
    is never left in the incremented state.
    """
    data_args.cutoff_len += 1
    try:
        return get_dataset(template, model_args, data_args, training_args, stage=stage, **tokenizer_module)
    finally:
        data_args.cutoff_len -= 1


def _plot_loss_curves(
    trainer: Any,
    training_args: "McaSeq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    dataset_module: Any,
    keys: list[str],
) -> None:
    """Plot ``keys`` plus the eval-loss key(s) when on process zero and ``plot_loss`` is set.

    When ``dataset_module["eval_dataset"]`` is a dict, one ``eval_<name>_loss`` key is
    added per entry; otherwise a single ``eval_loss`` key is added.
    """
    if not (trainer.is_world_process_zero() and finetuning_args.plot_loss):
        return

    keys = list(keys)  # do not mutate the caller's list
    eval_dataset = dataset_module.get("eval_dataset")
    if isinstance(eval_dataset, dict):
        keys += [f"eval_{key}_loss" for key in eval_dataset]
    else:
        keys += ["eval_loss"]

    plot_loss(training_args.output_dir, keys=keys)


def run_pt(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "McaSeq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
) -> None:
    """Run the pre-training workflow with ``McaTrainer``.

    Loads tokenizer, template and the ``"pt"`` stage dataset, verifies model support,
    builds the model and a shifted ``DataCollatorForSeq2Seq``, then trains only when
    ``training_args.do_train`` is set. After training the model, metrics and trainer
    state are saved and the loss is optionally plotted.
    """
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    dataset_module = _get_dataset_with_shifted_cutoff(
        template, model_args, data_args, training_args, "pt", tokenizer_module
    )

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    if training_args.do_train:
        train_result = trainer.train(training_args.resume_from_checkpoint)
        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        _plot_loss_curves(trainer, training_args, finetuning_args, dataset_module, ["loss"])


def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "McaSeq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
) -> None:
    """Run the supervised fine-tuning workflow with ``McaTrainer``.

    Aligns the packing flags between ``data_args`` and ``training_args``, loads the
    ``"sft"`` stage dataset, builds the model (optionally freezing qwen_vl parts) and a
    shifted ``SFTDataCollatorWith4DAttentionMask``, then trains, saves the model,
    metrics and trainer state, and optionally plots the loss.

    Note:
        ``data_args.neat_packing``, ``data_args.packing`` and
        ``training_args.sequence_packing`` are modified in place.
    """
    # align packing flags
    # TODO: FIX SequencePacking
    use_packing = data_args.neat_packing or training_args.sequence_packing
    data_args.neat_packing = training_args.sequence_packing = use_packing
    data_args.packing = data_args.neat_packing or data_args.packing

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    dataset_module = _get_dataset_with_shifted_cutoff(
        template, model_args, data_args, training_args, "sft", tokenizer_module
    )

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    # optional freezing for qwen_vl series
    _freeze_model_parameters(model, finetuning_args)

    # Pad every batch to the full cutoff length when expert parallelism is enabled (> 1).
    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        pad_to_multiple_of=64,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaTrainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    train_result = trainer.train(training_args.resume_from_checkpoint)
    trainer.save_model()
    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    _plot_loss_curves(trainer, training_args, finetuning_args, dataset_module, ["loss"])


def run_dpo(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "McaSeq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
) -> None:
    """Run the DPO workflow with ``McaDPOTrainer``.

    Builds the policy model (optionally freezing qwen_vl parts) and, when
    ``finetuning_args.use_ref_model`` is set, a reference model created from the same
    config and initialised from the policy model's state dict. Loads the ``"rm"`` stage
    dataset, trains with a shifted ``PairwiseDataCollatorWithPadding``, saves the model,
    metrics and trainer state, and optionally plots the loss and reward accuracy.
    """
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)

    _check_model_support(model_args)
    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

    _freeze_model_parameters(model, finetuning_args)

    if finetuning_args.use_ref_model:
        ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
        ref_model = AutoModel.from_config(ref_config)
        ref_model.load_state_dict(model.state_dict())
    else:
        ref_model = None

    dataset_module = _get_dataset_with_shifted_cutoff(
        template, model_args, data_args, training_args, "rm", tokenizer_module
    )

    # Pad every batch to the full cutoff length when expert parallelism is enabled (> 1).
    pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1

    dpo_config = DPOConfig(
        beta=finetuning_args.pref_beta,
        pref_loss=finetuning_args.pref_loss,
        label_smoothing=finetuning_args.dpo_label_smoothing,
    )
    data_collator = PairwiseDataCollatorWithPadding(
        template=template,
        pad_to_multiple_of=64,
        padding="max_length" if pad_to_max else "longest",
        max_length=data_args.cutoff_len if pad_to_max else None,
        label_pad_token_id=IGNORE_INDEX,
        **tokenizer_module,
    )
    data_collator = _data_collator_wrapper(data_collator)

    trainer = McaDPOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_config=dpo_config,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        **dataset_module,
    )

    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))

    train_result = trainer.train(training_args.resume_from_checkpoint)
    trainer.save_model()
    if finetuning_args.include_effective_tokens_per_second:
        train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
            dataset_module["train_dataset"], train_result.metrics, stage="rm"
        )

    trainer.log_metrics("train", train_result.metrics)
    trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()
    _plot_loss_curves(trainer, training_args, finetuning_args, dataset_module, ["loss", "rewards/accuracies"])