mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-31 11:10:35 +08:00)
[misc] lint mca code (#9692)
@@ -63,6 +63,7 @@ MCA_SUPPORTED_MODELS = {
     "qwen2",
     "qwen2_vl",
     "qwen2_5_vl",
+    "qwen3_vl",
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
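The hunk above adds "qwen3_vl" to the MCA_SUPPORTED_MODELS allow-list; the _check_model_support helper seen in the later hunks reads the Hugging Face config before AutoModel.from_pretrained is called. A plausible shape for such an allow-list check, purely illustrative (the helper body and error message below are assumptions, not code from this commit):

from transformers import AutoConfig as HfAutoConfig

# Illustrative allow-list mirroring the entries visible in the hunk above.
MCA_SUPPORTED_MODELS = {
    "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3", "qwen3_moe", "qwen3_next",
}


def check_model_support(model_name_or_path: str) -> None:
    # Reading only the config is cheap and tells us the architecture family.
    config = HfAutoConfig.from_pretrained(model_name_or_path)
    if config.model_type not in MCA_SUPPORTED_MODELS:
        raise ValueError(f"Model type {config.model_type!r} is not supported by the mcore_adapter backend.")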
@@ -11,14 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
-
-from __future__ import annotations
 
 import functools
 from collections.abc import Sequence
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Optional
 
+from transformers import DataCollatorForSeq2Seq
+
 from ...data import (
     SFTDataCollatorWith4DAttentionMask,
@@ -44,11 +43,11 @@ from mcore_adapter.models import AutoConfig, AutoModel
 from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
 from mcore_adapter.trainer import McaTrainer
 from mcore_adapter.trainer.dpo_config import DPOConfig
-from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
 
 
 if TYPE_CHECKING:
-    from transformers import DataCollatorForSeq2Seq, TrainerCallback
+    from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+    from transformers import TrainerCallback
 
     from ...hparams import DataArguments, FinetuningArguments, ModelArguments
 
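The import reshuffle above serves two ends: DataCollatorForSeq2Seq becomes a regular runtime import because run_pt now instantiates it directly (see the -104,10 hunk below), while names needed only for annotations (McaSeq2SeqTrainingArguments, TrainerCallback) move under if TYPE_CHECKING: and are referenced as quoted strings. A generic sketch of the pattern, not taken from the repository:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed at runtime.
    from transformers import TrainerCallback


def run_stage(callbacks: Optional[list["TrainerCallback"]] = None) -> None:
    # The quoted forward reference keeps the annotation valid even though
    # TrainerCallback is never bound when this module is imported.
    ...

Only static analysis resolves the quoted name; at runtime the annotation stays an ordinary string, so the heavy import is avoided.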
@@ -76,7 +75,7 @@ def _data_collator_wrapper(data_collator: Any):
     return wrapper
 
 
-def _check_model_support(model_args: ModelArguments):
+def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig
 
     config = HfAutoConfig.from_pretrained(
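The hunk header above mentions _data_collator_wrapper, whose body lies outside this diff. For orientation only, a hypothetical sketch of the wrapping idiom that the functools import suggests; the real wrapper's logic is not shown in this commit:

import functools
from typing import Any


def data_collator_wrapper(data_collator: Any):
    # Hypothetical: call the wrapped collator, then post-process its batch.
    @functools.wraps(data_collator)
    def wrapper(features):
        batch = data_collator(features)
        # e.g. rename or reshape keys for the Megatron-core adapter here.
        return batch

    return wrapper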
@@ -87,11 +86,11 @@ def _check_model_support(model_args: ModelArguments):
 
 
 def run_pt(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]
@@ -104,10 +103,7 @@ def run_pt(
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
-    from transformers import DataCollatorForSeq2Seq
-
-    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
+    data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
         label_pad_token_id=IGNORE_INDEX,
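With DataCollatorForSeq2Seq imported at module level, run_pt drops both the function-local import and the redundant variable annotation. The call itself is standard transformers usage; a standalone sketch for context (the checkpoint name is arbitrary, and IGNORE_INDEX is -100 in LLaMA-Factory):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,  # pad sequence length up to a multiple of 8
    label_pad_token_id=IGNORE_INDEX,  # padded label positions are ignored by the loss
)

# Pads input_ids, attention_mask and labels of both examples to a common length.
batch = data_collator([
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [4, 5]},
])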
@@ -142,11 +138,11 @@ def run_pt(
 
 
 def run_sft(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     # align packing flags
     # TODO: FIX SequencePacking
@@ -166,7 +162,7 @@ def run_sft(
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
 
     # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"]:
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
         params_to_freeze = []
         if finetuning_args.freeze_vision_tower:
             params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
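The freeze branch now also matches qwen3_vl checkpoints. The hunk only shows the prefixes being collected into params_to_freeze; a hedged sketch of how such a prefix list is typically applied afterwards (the actual loop is outside this diff):

import torch


def freeze_by_prefix(model: torch.nn.Module, params_to_freeze: list[str]) -> None:
    # Disable gradients for every parameter whose name starts with one of the
    # collected prefixes, e.g. "vision_model.blocks" or "vision_model.patch_embed".
    for name, param in model.named_parameters():
        if any(name.startswith(prefix) for prefix in params_to_freeze):
            param.requires_grad_(False)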
@@ -220,11 +216,11 @@ def run_sft(
 
 
 def run_dpo(
-    model_args: ModelArguments,
-    data_args: DataArguments,
-    training_args: McaSeq2SeqTrainingArguments,
-    finetuning_args: FinetuningArguments,
-    callbacks: list[TrainerCallback] | None = None,
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "McaSeq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     tokenizer_module = load_tokenizer(model_args)
     tokenizer = tokenizer_module["tokenizer"]