mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-31 11:10:35 +08:00
[misc] lint mca code (#9692)
@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
     "qwen2",
     "qwen2_vl",
     "qwen2_5_vl",
+    "qwen3_vl",
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
-
-from __future__ import annotations
-
 import functools
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
+
+from transformers import DataCollatorForSeq2Seq
 
 from ...data import (
     SFTDataCollatorWith4DAttentionMask,
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
 from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
 from mcore_adapter.trainer import McaTrainer
 from mcore_adapter.trainer.dpo_config import DPOConfig
-from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
 
 
 if TYPE_CHECKING:
-    from transformers import DataCollatorForSeq2Seq, TrainerCallback
+    from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+    from transformers import TrainerCallback
 
     from ...hparams import DataArguments, FinetuningArguments, ModelArguments
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
     return wrapper
 
 
-def _check_model_support(model_args: ModelArguments):
+def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig
 
     config = HfAutoConfig.from_pretrained(
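
The body of `_check_model_support` is cut off in this hunk. As a self-contained sketch of what such a gate could look like, using only the `MCA_SUPPORTED_MODELS` entries visible in the first hunk (the standalone signature and the error message are assumptions, not the repository's code):

```python
# Hypothetical reconstruction of the truncated check; only the entries visible
# in the first hunk are listed, and the error message is an assumption.
from transformers import AutoConfig as HfAutoConfig

MCA_SUPPORTED_MODELS = {"qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3", "qwen3_moe", "qwen3_next"}


def check_model_support(model_name_or_path: str) -> None:
    # Load only the Hugging Face config (cheap) and compare its model_type
    # against the set of architectures the mcore_adapter backend handles.
    config = HfAutoConfig.from_pretrained(model_name_or_path)
    if getattr(config, "model_type", None) not in MCA_SUPPORTED_MODELS:
        raise ValueError(f"Model type {config.model_type!r} is not supported by mcore_adapter.")
```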
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):
 
 
 def run_pt(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
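
The signature rewrites in this commit all follow the same lint rule: drop `from __future__ import annotations` and PEP 604 unions in favor of `Optional[...]` with quoted forward references, so type-only names can stay behind `TYPE_CHECKING`. A minimal standalone illustration of the two styles (the stub function below is illustrative, not repository code):

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:  # resolved by type checkers only, never imported at runtime
    from transformers import TrainerCallback


# Old style (removed by the lint): PEP 604 unions with bare names, which only
# work at runtime because `from __future__ import annotations` defers evaluation:
#     def run_pt(..., callbacks: list[TrainerCallback] | None = None): ...

# New style: quoted forward references plus Optional; the strings are never
# evaluated at runtime, so no __future__ import is needed.
def run_pt_stub(callbacks: Optional[list["TrainerCallback"]] = None) -> None:
    """Stub that only demonstrates the annotation style."""
    print("callbacks:", 0 if callbacks is None else len(callbacks))


run_pt_stub()  # prints "callbacks: 0"
```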
@@ -104,10 +103,7 @@ def run_pt(
 
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
-
-    from transformers import DataCollatorForSeq2Seq
-
-    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
+    data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX,
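
With `DataCollatorForSeq2Seq` now imported at module level, the local import and the redundant annotation inside `run_pt` go away. For reference, a standalone usage of the same collator configuration (the tokenizer checkpoint and the toy features are made up for illustration):

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

IGNORE_INDEX = -100  # label positions with this value are ignored by the loss

# Any tokenizer works here; "Qwen/Qwen2.5-0.5B" is just an illustrative choice.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
    label_pad_token_id=IGNORE_INDEX,
)

# The collator pads input_ids/attention_mask with the tokenizer's pad token and
# pads labels with IGNORE_INDEX, turning ragged features into one tensor batch.
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
]
batch = data_collator(features)
print(batch["input_ids"].shape, batch["labels"].shape)
```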
@@ -142,11 +138,11 @@ def run_pt(
 
 
 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     # align packing flags
     # TODO: FIX SequencePacking
@@ -166,7 +162,7 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
         params_to_freeze = []
         if finetuning_args.freeze_vision_tower:
             params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
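
This hunk only extends the model-type list; how the `params_to_freeze` prefixes are applied is not shown here. A minimal sketch of prefix-based freezing under that assumption (the toy model and helper below are illustrative, not the mcore_adapter implementation):

```python
import torch.nn as nn


def freeze_by_prefix(model: nn.Module, prefixes: list[str]) -> int:
    """Disable gradients for every parameter whose name starts with one of the prefixes."""
    frozen = 0
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in prefixes):
            param.requires_grad_(False)
            frozen += 1
    return frozen


# Toy stand-in for a VL model with a vision tower; the real prefixes in the diff
# are "vision_model.blocks" and "vision_model.patch_embed".
model = nn.ModuleDict({
    "vision_model": nn.ModuleDict({"patch_embed": nn.Linear(4, 4), "blocks": nn.Linear(4, 4)}),
    "language_model": nn.Linear(4, 4),
})
print(freeze_by_prefix(model, ["vision_model.blocks", "vision_model.patch_embed"]))  # 4 params frozen
```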
@@ -220,11 +216,11 @@ def run_sft(
 
 
 def run_dpo(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]