diff --git a/docker/docker-cuda/Dockerfile.megatron b/docker/docker-cuda/Dockerfile.megatron
new file mode 100644
index 00000000..7b603e3e
--- /dev/null
+++ b/docker/docker-cuda/Dockerfile.megatron
@@ -0,0 +1,77 @@
+# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
+# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-05.html
+FROM nvcr.io/nvidia/pytorch:24.05-py3
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PIP_ROOT_USER_ACTION=ignore
+ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
+ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
+ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
+
+RUN pip install --upgrade pip setuptools wheel --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
+
+RUN pip uninstall -y torch torchvision torch-tensorrt \
+    flash_attn transformer-engine \
+    cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf
+
+RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
+
+RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
+    rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
+    pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
+
+RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
+    transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
+    --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
+
+RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
+
+# RUN pip install vllm==0.8.4 \
+#     --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
+
+WORKDIR /build
+
+ARG apex_url=git+https://github.com/NVIDIA/apex.git@25.04
+RUN pip uninstall -y apex && \
+    MAX_JOBS=32 NINJA_FLAGS="-j32" NVCC_APPEND_FLAGS="--threads 32" \
+    pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
+    --config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}
+
+RUN rm -rf /build
+WORKDIR /workspace
+
+RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
+    { \
+        echo "deb ${APT_MIRROR} jammy main restricted universe multiverse"; \
+        echo "deb ${APT_MIRROR} jammy-security main restricted universe multiverse"; \
+        echo "deb ${APT_MIRROR} jammy-updates main restricted universe multiverse"; \
+        echo "deb ${APT_MIRROR} jammy-backports main restricted universe multiverse"; \
+    } > /etc/apt/sources.list
+
+RUN apt-get update && apt-get install -y zip
+
+RUN apt-get install -y openjdk-21-jdk
+ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
+
+# install LLaMA-Factory
+WORKDIR /app
+
+COPY requirements.txt /app/
+RUN pip install --no-cache-dir -r requirements.txt
+
+RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
+
+COPY . /app/
+RUN pip install -e ".[metrics]" --no-build-isolation
+
+# Expose port 7860 for LLaMA Board
+ENV GRADIO_SERVER_PORT=7860
+EXPOSE 7860
+
+# Expose port 8000 for API service
+ENV API_PORT=8000
+EXPOSE 8000
+
+# unset proxy
+ENV http_proxy=
+ENV https_proxy=
diff --git a/examples/megatron/qwen2_vl_full.yaml b/examples/megatron/qwen2_vl_full.yaml
new file mode 100644
index 00000000..5e22e3ec
--- /dev/null
+++ b/examples/megatron/qwen2_vl_full.yaml
@@ -0,0 +1,29 @@
+model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
+image_max_pixels: 262144
+video_max_pixels: 16384
+
+do_train: true
+stage: sft
+finetuning_type: full  # only full is supported for now
+dataset: llava_1k_en
+preprocessing_num_workers: 8
+cutoff_len: 4096
+template: qwen2_vl
+
+output_dir: saves/mca/qwen2_vl_full
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 2
+num_train_epochs: 2
+learning_rate: 2e-5
+logging_steps: 1
+save_steps: 100
+lr_scheduler_type: cosine
+bf16: true
+
+# mcore speed up
+tensor_model_parallel_size: 4
+sequence_parallel: true
+pipeline_model_parallel_size: 2
+bias_activation_fusion: true
+apply_rope_fusion: true
+use_distributed_optimizer: true
diff --git a/examples/megatron/qwen3_moe_full.yaml b/examples/megatron/qwen3_moe_full.yaml
new file mode 100644
index 00000000..3b62bf91
--- /dev/null
+++ b/examples/megatron/qwen3_moe_full.yaml
@@ -0,0 +1,35 @@
+model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
+
+# GPU memory: 8 * 78GB
+do_train: true
+stage: sft
+finetuning_type: full  # only full is supported for now
+dataset: alpaca_en_demo
+preprocessing_num_workers: 8
+cutoff_len: 4096
+template: qwen3_nothink
+
+# global batch size = (8 // 2 // 4) * 8 = 8
+output_dir: saves/mca/qwen3_moe_full
+per_device_train_batch_size: 1
+gradient_accumulation_steps: 8
+num_train_epochs: 2
+learning_rate: 3e-6
+logging_steps: 1
+save_steps: 100
+lr_scheduler_type: constant
+bf16: true
+
+# mcore speed up
+tensor_model_parallel_size: 1
+sequence_parallel: false
+pipeline_model_parallel_size: 4
+bias_activation_fusion: true
+apply_rope_fusion: true
+use_distributed_optimizer: true
+overlap_param_gather: true
+overlap_grad_reduce: true
+moe_grouped_gemm: true
+moe_token_dispatcher_type: alltoall
+expert_model_parallel_size: 2
+recompute_granularity: full
diff --git a/scripts/megatron_merge.py b/scripts/megatron_merge.py
new file mode 100644
index 00000000..47ad98f0
--- /dev/null
+++ b/scripts/megatron_merge.py
@@ -0,0 +1,125 @@
+# Copyright 2025 the ROLL team and the LlamaFactory team.
+#
+# This code is modified from the ROLL library.
+# https://github.com/alibaba/ROLL/blob/main/mcore_adapter/tools/convert.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Optional
+
+import fire
+import torch
+from mcore_adapter.models.converter.post_converter import convert_checkpoint_to_hf, convert_checkpoint_to_mca
+from mcore_adapter.training_args import DistributingParallelArguments
+from mcore_adapter.utils import get_logger
+from transformers import AutoConfig
+
+
+logger = get_logger(__name__)
+
+
+def convert_mca_to_hf(
+    checkpoint_path: str,
+    output_path: str = "./output",
+    bf16: bool = False,
+    fp16: bool = False,
+    convert_model_max_length: Optional[int] = None,
+):
+    """Convert a Megatron (MCA) checkpoint to HuggingFace format.
+
+    Args:
+        checkpoint_path: Path to the checkpoint to convert
+        output_path: Path to save the converted checkpoint
+        bf16: Use bfloat16 precision
+        fp16: Use float16 precision
+        convert_model_max_length: Change the model_max_length in hf config.json
+    """
+    if bf16 and fp16:
+        raise ValueError("bf16 and fp16 cannot be both True.")
+
+    torch_dtype = None
+    if bf16:
+        torch_dtype = torch.bfloat16
+    elif fp16:
+        torch_dtype = torch.float16
+
+    convert_checkpoint_to_hf(checkpoint_path, output_path, torch_dtype=torch_dtype)
+
+    if convert_model_max_length is not None:
+        config = AutoConfig.from_pretrained(output_path, trust_remote_code=True)
+        config.model_max_length = convert_model_max_length
+        config.save_pretrained(output_path)
+
+
+def convert(
+    checkpoint_path: str,
+    output_path: str = "./output",
+    bf16: bool = False,
+    fp16: bool = False,
+    convert_model_max_length: Optional[int] = None,
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    expert_model_parallel_size: int = 1,
+    virtual_pipeline_model_parallel_size: Optional[int] = None,
+):
+    """Convert a checkpoint between MCA and HuggingFace formats.
+
+    Args:
+        checkpoint_path: Path to the checkpoint to convert
+        output_path: Path to save the converted checkpoint
+        bf16: Use bfloat16 precision
+        fp16: Use float16 precision
+        convert_model_max_length: Change the model_max_length in hf config.json
+        tensor_model_parallel_size: Tensor model parallel size
+        pipeline_model_parallel_size: Pipeline model parallel size
+        expert_model_parallel_size: Expert model parallel size
+        virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
+    """
+    if bf16 and fp16:
+        raise ValueError("bf16 and fp16 cannot be both True.")
+
+    mca_config_path = os.path.join(checkpoint_path, "mca_config.json")
+    from_mca = os.path.exists(mca_config_path)
+
+    if not from_mca:
+        dist_args = DistributingParallelArguments(
+            tensor_model_parallel_size=tensor_model_parallel_size,
+            pipeline_model_parallel_size=pipeline_model_parallel_size,
+            expert_model_parallel_size=expert_model_parallel_size,
+            virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+        )
+
+        convert_checkpoint_to_mca(
+            checkpoint_path,
+            output_path,
+            dist_args,
+            bf16=bf16,
+            fp16=fp16,
+        )
+    else:
+        convert_mca_to_hf(
+            checkpoint_path=checkpoint_path,
+            output_path=output_path,
+            bf16=bf16,
+            fp16=fp16,
+            convert_model_max_length=convert_model_max_length,
+        )
+
+
+def main():
+    fire.Fire(convert)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 51ea9144..e84e088b 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -56,6 +56,8 @@ LAYERNORM_NAMES = {"norm", "ln"}
 
 LLAMABOARD_CONFIG = "llamaboard_config.yaml"
 
+MCA_SUPPORTED_MODELS = {"deepseek_v3", "llama", "mistral", "mixtral", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "qwen3_next"}
+
 METHODS = ["full", "freeze", "lora", "oft"]
 
 MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 99b55f55..21265c8e 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -70,6 +70,10 @@ def is_matplotlib_available():
     return _is_package_available("matplotlib")
 
 
+def is_mcore_adapter_available():
+    return _is_package_available("mcore_adapter")
+
+
 def is_pillow_available():
     return _is_package_available("PIL")
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 3130f86e..6a3aaaff 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -461,7 +461,7 @@ class FinetuningArguments(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
-    finetuning_type: Literal["lora", "freeze", "full"] = field(
+    finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."},
     )
@@ -473,6 +473,10 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to use the Adam-mini optimizer."},
     )
+    use_mca: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use MCA (Megatron Core Adapter) training. Controlled by the USE_MCA environment variable."},
+    )
     use_muon: bool = field(
         default=False,
         metadata={"help": "Whether or not to use the Muon optimizer."},
diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index e0752818..eca60407 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES, EngineName
 from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
-from ..extras.packages import is_transformers_version_greater_than
+from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
 from .finetuning_args import FinetuningArguments
@@ -53,6 +53,13 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
 _EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 _EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
 
+if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
+    from mcore_adapter import TrainingArguments as McaTrainingArguments
+    _TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+    _TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
+else:
+    _TRAIN_MCA_ARGS = []
+    _TRAIN_MCA_CLS = tuple()
 
 def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
     r"""Get arguments from the command line or a config file."""
@@ -197,6 +204,27 @@ def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
+def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
+    parser = HfArgumentParser(_TRAIN_MCA_ARGS)
+    allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
+    model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
+        parser, args, allow_extra_keys=allow_extra_keys
+    )
+
+    _configure_mca_training_args(training_args, data_args, finetuning_args)
+
+    return model_args, data_args, training_args, finetuning_args, generating_args
+
+
+def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
+    """Patch training args to avoid argument checking errors and sync MCA settings."""
+    training_args.predict_with_generate = False
+    training_args.generation_max_length = data_args.cutoff_len
+    training_args.generation_num_beams = 1
+    training_args.use_mca = True
+    finetuning_args.use_mca = True
+
+
 def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
@@ -216,7 +244,11 @@ def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
 
 
 def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
-    model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+    if is_env_enabled("USE_MCA"):
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
+    else:
+        model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
+        finetuning_args.use_mca = False
 
     # Setup logging
     if training_args.should_log:
diff --git a/src/llamafactory/hparams/training_args.py b/src/llamafactory/hparams/training_args.py
index 84b657a9..4c83a93c 100644
--- a/src/llamafactory/hparams/training_args.py
+++ b/src/llamafactory/hparams/training_args.py
@@ -19,7 +19,20 @@ from typing import Literal, Optional, Union
 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
 
-from ..extras.misc import use_ray
+from ..extras.misc import is_env_enabled, use_ray
+
+
+if is_env_enabled("USE_MCA"):
+    try:
+        from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+        BaseTrainingArguments = McaSeq2SeqTrainingArguments
+    except ImportError:
+        raise ImportError(
+            "mcore_adapter is required when USE_MCA=1. "
+            "Please install `mcore_adapter` and its dependencies."
+        )
+else:
+    BaseTrainingArguments = Seq2SeqTrainingArguments
 
 
 @dataclass
@@ -78,7 +91,7 @@ class RayArguments:
 
 
 @dataclass
-class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
+class TrainingArguments(RayArguments, BaseTrainingArguments):
     r"""Arguments pertaining to the trainer."""
 
     overwrite_output_dir: bool = field(
@@ -87,5 +100,5 @@ class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
     )
 
     def __post_init__(self):
-        Seq2SeqTrainingArguments.__post_init__(self)
         RayArguments.__post_init__(self)
+        BaseTrainingArguments.__post_init__(self)
diff --git a/src/llamafactory/launcher.py b/src/llamafactory/launcher.py
index ab955c03..db299ade 100644
--- a/src/llamafactory/launcher.py
+++ b/src/llamafactory/launcher.py
@@ -54,6 +54,10 @@ def launch():
     )
     command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
 
+    if is_env_enabled("USE_MCA"):
+        # force torchrun for MCA training
+        os.environ["FORCE_TORCHRUN"] = "1"
+
     if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
         # launch distributed training
         nnodes = os.getenv("NNODES", "1")
diff --git a/src/llamafactory/train/mca/__init__.py b/src/llamafactory/train/mca/__init__.py
new file mode 100644
index 00000000..2a68229e
--- /dev/null
+++ b/src/llamafactory/train/mca/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_dpo, run_pt, run_sft
+
+
+__all__ = ["run_dpo", "run_pt", "run_sft"]
+
diff --git a/src/llamafactory/train/mca/trainer.py b/src/llamafactory/train/mca/trainer.py
new file mode 100644
index 00000000..97cc9b71
--- /dev/null
+++ b/src/llamafactory/train/mca/trainer.py
@@ -0,0 +1,15 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# TODO: override the original trainer
diff --git a/src/llamafactory/train/mca/workflow.py b/src/llamafactory/train/mca/workflow.py
new file mode 100644
index 00000000..2aa523db
--- /dev/null
+++ b/src/llamafactory/train/mca/workflow.py
@@ -0,0 +1,292 @@
+# Copyright 2025 the ROLL team and the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
+
+from __future__ import annotations
+
+import functools
+from collections.abc import Sequence
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any
+
+from ...data import (
+    SFTDataCollatorWith4DAttentionMask,
+    get_dataset,
+    get_template_and_fix_tokenizer,
+)
+from ...data.collator import (
+    PairwiseDataCollatorWithPadding,
+)
+from ...extras.constants import IGNORE_INDEX, MCA_SUPPORTED_MODELS
+from ...extras.logging import get_logger
+from ...extras.misc import calculate_tps
+from ...extras.packages import is_mcore_adapter_available
+from ...extras.ploting import plot_loss
+from ...model import load_tokenizer
+from ..callbacks import SaveProcessorCallback
+
+
+if not is_mcore_adapter_available():
+    raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
+
+from mcore_adapter.models import AutoConfig, AutoModel
+from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
+from mcore_adapter.trainer import McaTrainer
+from mcore_adapter.trainer.dpo_config import DPOConfig
+from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
+
+
+if TYPE_CHECKING:
+    from transformers import DataCollatorForSeq2Seq, TrainerCallback
+
+    from ...hparams import DataArguments, FinetuningArguments, ModelArguments
+
+
+logger = get_logger(__name__)
+
+
+def _data_collator_wrapper(data_collator: Any):
+    """Wrap a collator so labels are shifted left by one token and the last input token is dropped before collation."""
+    @functools.wraps(data_collator)
+    def wrapper(features: Sequence[dict[str, Any]]):
+        labels_key = [k for k in features[0].keys() if k.endswith("labels")]
+        input_ids_key = [k for k in features[0].keys() if k.endswith("input_ids")]
+        for feature in features:
+            if len(labels_key) == 0:  # pt
+                feature["labels"] = deepcopy(feature["input_ids"])[1:]
+            for k in labels_key:
+                feature[k] = feature[k][1:]
+            for k in input_ids_key:
+                feature[k] = feature[k][:-1]
+            for k in ["attention_mask", "position_ids"]:
+                if k in feature:
+                    feature[k] = feature[k][:-1]
+        return data_collator(features)
+
+    return wrapper
+
+
+def _check_model_support(model_args: ModelArguments):
+    from transformers import AutoConfig as HfAutoConfig
+
+    config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
+    if config.model_type not in MCA_SUPPORTED_MODELS:
+        raise ValueError(f"Model {config.model_type} is not supported by MCA.")
+
+
+def run_pt(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: McaSeq2SeqTrainingArguments,
+    finetuning_args: FinetuningArguments,
+    callbacks: list[TrainerCallback] | None = None,
+):
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+
+    # dataset needs +1 then cut back due to MCA shift logic
+    data_args.cutoff_len += 1
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
+    data_args.cutoff_len -= 1
+
+    _check_model_support(model_args)
+    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
+
+    from transformers import DataCollatorForSeq2Seq
+
+    data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
+        tokenizer=tokenizer,
+        pad_to_multiple_of=8,
+        label_pad_token_id=IGNORE_INDEX,
+    )
+    data_collator = _data_collator_wrapper(data_collator)
+
+    trainer = McaTrainer(
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **dataset_module,
+    )
+
+    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
+        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
+
+    if training_args.do_train:
+        train_result = trainer.train(training_args.resume_from_checkpoint)
+        trainer.save_model()
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            keys = ["loss"]
+            if isinstance(dataset_module.get("eval_dataset"), dict):
+                keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+            else:
+                keys += ["eval_loss"]
+
+            plot_loss(training_args.output_dir, keys=keys)
+
+
+def run_sft(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: McaSeq2SeqTrainingArguments,
+    finetuning_args: FinetuningArguments,
+    callbacks: list[TrainerCallback] | None = None,
+):
+    # align packing flags
+    # TODO: FIX SequencePacking
+    data_args.neat_packing = training_args.sequence_packing = data_args.neat_packing or training_args.sequence_packing
+    data_args.packing = data_args.neat_packing or data_args.packing
+
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+
+    # dataset needs +1 then cut back due to MCA shift logic
+    data_args.cutoff_len += 1
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    data_args.cutoff_len -= 1
+
+    _check_model_support(model_args)
+    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
+
+    # optional freezing for qwen2_vl, qwen2_5_vl
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
+        for name, p in model.named_parameters():
+            if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
+                p.requires_grad_(False)
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
+        for name, p in model.named_parameters():
+            if any(name.startswith(k) for k in ["multi_modal_projector"]):
+                p.requires_grad_(False)
+    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
+        for name, p in model.named_parameters():
+            if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
+                p.requires_grad_(False)
+
+    pad_to_max = (
+        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
+    )
+    data_collator = SFTDataCollatorWith4DAttentionMask(
+        template=template,
+        padding="max_length" if pad_to_max else "longest",
+        max_length=data_args.cutoff_len if pad_to_max else None,
+        pad_to_multiple_of=64,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    data_collator = _data_collator_wrapper(data_collator)
+
+    trainer = McaTrainer(
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **dataset_module,
+    )
+
+    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
+        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
+
+    train_result = trainer.train(training_args.resume_from_checkpoint)
+    trainer.save_model()
+    trainer.log_metrics("train", train_result.metrics)
+    trainer.save_metrics("train", train_result.metrics)
+    trainer.save_state()
+    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+        keys = ["loss"]
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+        else:
+            keys += ["eval_loss"]
+
+        plot_loss(training_args.output_dir, keys=keys)
+
+
+def run_dpo(
+    model_args: ModelArguments,
+    data_args: DataArguments,
+    training_args: McaSeq2SeqTrainingArguments,
+    finetuning_args: FinetuningArguments,
+    callbacks: list[TrainerCallback] | None = None,
+):
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+
+    _check_model_support(model_args)
+    model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
+
+    if finetuning_args.use_ref_model:
+        ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
+        ref_model = AutoModel.from_config(ref_config)
+        ref_model.load_state_dict(model.state_dict())
+    else:
+        ref_model = None
+
+    # dataset needs +1 then cut back due to MCA shift logic
+    data_args.cutoff_len += 1
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
+    data_args.cutoff_len -= 1
+
+    pad_to_max = (
+        training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
+    )
+    dpo_config = DPOConfig(
+        beta=finetuning_args.pref_beta,
+        pref_loss=finetuning_args.pref_loss,
+        label_smoothing=finetuning_args.dpo_label_smoothing,
+    )
+    data_collator = PairwiseDataCollatorWithPadding(
+        template=template,
+        pad_to_multiple_of=64,
+        padding="max_length" if pad_to_max else "longest",
+        max_length=data_args.cutoff_len if pad_to_max else None,
+        label_pad_token_id=IGNORE_INDEX,
+        **tokenizer_module,
+    )
+    data_collator = _data_collator_wrapper(data_collator)
+
+    trainer = McaDPOTrainer(
+        model=model,
+        ref_model=ref_model,
+        args=training_args,
+        train_config=dpo_config,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **dataset_module,
+    )
+
+    if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
+        trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
+
+    train_result = trainer.train(training_args.resume_from_checkpoint)
+    trainer.save_model()
+    if finetuning_args.include_effective_tokens_per_second:
+        train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+            dataset_module["train_dataset"], train_result.metrics, stage="rm"
+        )
+
+    trainer.log_metrics("train", train_result.metrics)
+    trainer.save_metrics("train", train_result.metrics)
+    trainer.save_state()
+    if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+        keys = ["loss", "rewards/accuracies"]
+        if isinstance(dataset_module.get("eval_dataset"), dict):
+            keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
+        else:
+            keys += ["eval_loss"]
+
+        plot_loss(training_args.output_dir, keys=keys)
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index cd22ba83..a3538ad3 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available
 from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -66,7 +66,19 @@ def _training_function(config: dict[str, Any]) -> None:
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
-    if finetuning_args.stage == "pt":
+    if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
+        if not is_mcore_adapter_available():
+            raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
+        if finetuning_args.stage == "pt":
+            from .mca import run_pt as run_pt_mca
+            run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+        elif finetuning_args.stage == "sft":
+            from .mca import run_sft as run_sft_mca
+            run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+        else:  # dpo
+            from .mca import run_dpo as run_dpo_mca
+            run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif finetuning_args.stage == "pt":
         run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "sft":
         run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)