Compare commits

No commits in common. "13170577b2dbde8384ee7cb57eac3943f8faf6d8" and "9c0d033a15da97cdf10075beb7317f4ef6123791" have entirely different histories.

16 changed files with 42 additions and 760 deletions

View File

@ -1,77 +0,0 @@
# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
FROM nvcr.io/nvidia/pytorch:24.05-py3
ENV DEBIAN_FRONTEND=noninteractive
ENV PIP_ROOT_USER_ACTION=ignore
ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
RUN pip install --upgrade pip setuptools wheel --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip uninstall -y torch torchvision torch-tensorrt \
flash_attn transformer-engine \
cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
--trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# RUN pip install vllm==0.8.4 \
# --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
WORKDIR /build
ARG apex_url=git+https://github.com/NVIDIA/apex.git@25.04
RUN pip uninstall -y apex && \
MAX_JOBS=32 NINJA_FLAGS="-j32" NVCC_APPEND_FLAGS="--threads 32" \
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
--config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}
RUN rm -rf /build
WORKDIR /workspace
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
{ \
echo "deb ${APT_MIRROR} jammy main restricted universe multiverse"; \
echo "deb ${APT_MIRROR} jammy-security main restricted universe multiverse"; \
echo "deb ${APT_MIRROR} jammy-updates main restricted universe multiverse"; \
echo "deb ${APT_MIRROR} jammy-backports main restricted universe multiverse"; \
} > /etc/apt/sources.list
RUN apt-get update && apt-get install -y zip
RUN apt-get install -y openjdk-21-jdk
ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
# pip install LLaMA-Factory
WORKDIR /app
COPY requirements.txt /app/
RUN pip install --no-cache-dir -r requirements.txt
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
COPY . /app/
RUN pip install -e ".[metrics]" --no-build-isolation
# Expose port 7860 for LLaMA Board
ENV GRADIO_SERVER_PORT=7860
EXPOSE 7860
# Expose port 8000 for API service
ENV API_PORT=8000
EXPOSE 8000
# unset proxy
ENV http_proxy=
ENV https_proxy=

View File

@ -1,29 +0,0 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
image_max_pixels: 262144
video_max_pixels: 16384
do_train: true
stage: sft
finetuning_type: full # only support full for now
dataset: llava_1k_en
preprocessing_num_workers: 8
cutoff_len: 4096
template: qwen2_vl
output_dir: saves/mca/qwen2_vl_full
per_device_train_batch_size: 1
gradient_accumulation_steps: 2
num_train_epochs: 2
learning_rate: 2e-5
logging_steps: 1
save_steps: 100
lr_scheduler_type: cosine
bf16: true
# mcore speed up
tensor_model_parallel_size: 4
sequence_parallel: true
pipeline_model_parallel_size: 2
bias_activation_fusion: true
apply_rope_fusion: true
use_distributed_optimizer: true

View File

@ -1,35 +0,0 @@
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
# GPU memory: 8 * 78GB
do_train: true
stage: sft
finetuning_type: full # only support full for now
dataset: alpaca_en_demo
preprocessing_num_workers: 8
cutoff_len: 4096
template: qwen3_nothink
# global batchsize = (8 // 2 // 4) * 8 = 8
output_dir: saves/mca/qwen3_moe_full
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
num_train_epochs: 2
learning_rate: 3e-6
logging_steps: 1
save_steps: 100
lr_scheduler_type: constant
bf16: true
# mcore speed up
tensor_model_parallel_size: 1
sequence_parallel: false
pipeline_model_parallel_size: 4
bias_activation_fusion: true
apply_rope_fusion: true
use_distributed_optimizer: true
overlap_param_gather: true
overlap_grad_reduce: true
moe_grouped_gemm: true
moe_token_dispatcher_type: alltoall
expert_model_parallel_size: 2
recompute_granularity: full

View File

@ -1,125 +0,0 @@
# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# This code is modified from the ROLL library.
# https://github.com/alibaba/ROLL/blob/main/mcore_adapter/tools/convert.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Optional
import fire
import torch
from mcore_adapter.models.converter.post_converter import convert_checkpoint_to_hf, convert_checkpoint_to_mca
from mcore_adapter.training_args import DistributingParallelArguments
from mcore_adapter.utils import get_logger
from transformers import AutoConfig
logger = get_logger(__name__)
def convert_mca_to_hf(
checkpoint_path: str,
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: Optional[int] = None,
):
"""Convert megatron checkpoint to HuggingFace format.
Args:
checkpoint_path: Path to the checkpoint to convert
output_path: Path to save the converted checkpoint
bf16: Use bfloat16 precision
fp16: Use float16 precision
convert_model_max_length: Change the model_max_length in hf config.json
"""
if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.")
torch_dtype = None
if bf16:
torch_dtype = torch.bfloat16
elif fp16:
torch_dtype = torch.float16
convert_checkpoint_to_hf(checkpoint_path, output_path, torch_dtype=torch_dtype)
if convert_model_max_length is not None:
config = AutoConfig.from_pretrained(output_path, trust_remote_code=True)
config.model_max_length = convert_model_max_length
config.save_pretrained(output_path)
def convert(
checkpoint_path: str,
output_path: str = "./output",
bf16: bool = False,
fp16: bool = False,
convert_model_max_length: Optional[int] = None,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
expert_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
):
"""Convert checkpoint between MCA and HuggingFace formats.
Args:
checkpoint_path: Path to the checkpoint to convert
output_path: Path to save the converted checkpoint
bf16: Use bfloat16 precision
fp16: Use float16 precision
convert_model_max_length: Change the model_max_length in hf config.json
tensor_model_parallel_size: Tensor model parallel size
pipeline_model_parallel_size: Pipeline model parallel size
expert_model_parallel_size: Expert model parallel size
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
"""
if bf16 and fp16:
raise ValueError("bf16 and fp16 cannot be both True.")
mca_config_path = os.path.join(checkpoint_path, "mca_config.json")
from_mca = os.path.exists(mca_config_path)
if not from_mca:
dist_args = DistributingParallelArguments(
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
expert_model_parallel_size=expert_model_parallel_size,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
)
convert_checkpoint_to_mca(
checkpoint_path,
output_path,
dist_args,
bf16=bf16,
fp16=fp16,
)
else:
convert_mca_to_hf(
checkpoint_path=checkpoint_path,
output_path=output_path,
bf16=bf16,
fp16=fp16,
convert_model_max_length=convert_model_max_length,
)
def main():
fire.Fire(convert)
if __name__ == "__main__":
main()

View File

@ -16,7 +16,6 @@ import gc
import json
from typing import Optional
-import av
import fire
from tqdm import tqdm
from transformers import Seq2SeqTrainingArguments
@ -34,14 +33,6 @@ if is_vllm_available():
from vllm.lora.request import LoRARequest
-def _need_video_kwargs(template):
-NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
-if any(t in template for t in NEEDED_TEMPLATE):
-return True
-return False
def vllm_infer(
model_name_or_path: str,
adapter_name_or_path: str = None,
@ -141,7 +132,6 @@ def vllm_infer(
# Store all results in these lists
all_prompts, all_preds, all_labels = [], [], []
-need_video_kwargs = _need_video_kwargs(template)
# Add batch process to avoid the issue of too many files opened
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
@ -157,7 +147,6 @@ def vllm_infer(
)["images"]
}
elif batch["videos"][j] is not None:
-video_metadata, video_metadata_kwargs = None, None
video = batch["videos"][j]
multi_modal_data = {
"video": template_obj.mm_plugin._regularize_videos(
@ -168,25 +157,6 @@ def vllm_infer(
video_maxlen=video_maxlen,
)["videos"]
}
-if need_video_kwargs:
-container = av.open(video[0], "r")
-video_stream = next(stream for stream in container.streams if stream.type == "video")
-sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
-video_stream, video_fps, video_maxlen
-)
-total_frames = video_stream.frames
-video_metadata_kwargs = {
-"fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
-"do_sample_frames": False,
-"total_num_frames": total_frames,
-}
-video_metadata = dict(
-fps=video_fps,
-frames_indices=sampling_indices,
-total_num_frames=total_frames,
-video_backend="opencv",
-)
-multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
elif batch["audios"][j] is not None:
audio = batch["audios"][j]
audio_data = template_obj.mm_plugin._regularize_audios(
@ -197,11 +167,7 @@ def vllm_infer(
else:
multi_modal_data = None
-vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
-if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
-vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
-vllm_inputs.append(vllm_input_data)
+vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
labels.append(
tokenizer.decode(

View File

@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
convert_sparse_cross_attention_mask_to_dense,
get_cross_attention_token_mask,
)
-from typing_extensions import NotRequired, override
+from typing_extensions import override
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import (
@ -77,18 +77,6 @@ if TYPE_CHECKING:
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
AudioInput = Union[str, BinaryIO, NDArray]
-class RegularizedImageOutput(TypedDict):
-images: list[ImageObject]
-class RegularizedVideoOutput(TypedDict):
-videos: list[list[ImageObject]]
-durations: list[float]
-fps_per_video: NotRequired[list[float]]
-class RegularizedAudioOutput(TypedDict):
-audios: list[NDArray]
-sampling_rates: list[float]
class MMProcessor(ProcessorMixin):
patch_size: int
image_seq_length: int
@ -256,7 +244,7 @@ class MMPluginMixin:
sample_frames = min(total_frames, video_maxlen, sample_frames)
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
-def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
+def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
r"""Regularize images to avoid error. Including reading and pre-processing."""
results = []
for image in images:
@ -277,10 +265,9 @@ class MMPluginMixin:
return {"images": results}
-def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
results = []
-durations = []
for video in videos:
frames: list[ImageObject] = []
if _check_video_is_nested_images(video):
@ -288,7 +275,6 @@ class MMPluginMixin:
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
raise ValueError("Invalid image found in video frames.")
frames = video
-durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
@ -298,19 +284,14 @@ class MMPluginMixin:
if frame_idx in sample_indices:
frames.append(frame.to_image())
-if video_stream.duration is None:
-durations.append(len(frames) / kwargs.get("video_fps", 2.0))
-else:
-durations.append(float(video_stream.duration * video_stream.time_base))
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
-return {"videos": results, "durations": durations}
+return {"videos": results}
def _regularize_audios(
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
-) -> "RegularizedAudioOutput":
+) -> dict[str, Union[list["NDArray"], list[float]]]:
r"""Regularizes audios to avoid error. Including reading and resampling."""
results, sampling_rates = [], []
for audio in audios:
@ -1437,8 +1418,10 @@ class Qwen2VLPlugin(BasePlugin):
return image
@override
-def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
-results, fps_per_video, durations = [], [], []
+def _regularize_videos(
+self, videos: list["VideoInput"], **kwargs
+) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
+results, fps_per_video = [], []
for video in videos:
frames: list[ImageObject] = []
if _check_video_is_nested_images(video):
@ -1448,7 +1431,6 @@ class Qwen2VLPlugin(BasePlugin):
frames = video
fps_per_video.append(kwargs.get("video_fps", 2.0))
-durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
container = av.open(video, "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
@ -1460,10 +1442,8 @@ class Qwen2VLPlugin(BasePlugin):
if video_stream.duration is None:
fps_per_video.append(kwargs.get("video_fps", 2.0))
-durations.append(len(frames) / kwargs.get("video_fps", 2.0))
else:
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
-durations.append(float(video_stream.duration * video_stream.time_base))
if len(frames) % 2 != 0:
frames.append(frames[-1])
@ -1471,7 +1451,7 @@ class Qwen2VLPlugin(BasePlugin):
frames = self._regularize_images(frames, **kwargs)["images"]
results.append(frames)
-return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
+return {"videos": results, "fps_per_video": fps_per_video}
@override
def _get_mm_inputs(
@ -1585,8 +1565,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_metadata = [
-{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
-for video, duration in zip(videos["videos"], videos["durations"])
+{"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
+for video in videos["videos"]
]
mm_inputs.update(
video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
@ -1642,7 +1622,6 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
num_image_tokens += 1
while VIDEO_PLACEHOLDER in content:
-if self.expand_mm_tokens:
metadata = video_metadata[idx]
timestamps = processor._calculate_timestamps(
metadata.frames_indices,
@ -1662,7 +1641,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
)
video_structure += frame_structure
-else:
+if not self.expand_mm_tokens:
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
@ -1704,8 +1684,7 @@ class GLM4VPlugin(Qwen2VLPlugin):
)
# prepare video metadata
video_metadata = [
-{"fps": 2, "duration": duration, "total_frames": len(video)}
-for video, duration in zip(video_data["videos"], video_data["durations"])
+{"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
]
mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))

View File

@ -56,8 +56,6 @@ LAYERNORM_NAMES = {"norm", "ln"}
LLAMABOARD_CONFIG = "llamaboard_config.yaml" LLAMABOARD_CONFIG = "llamaboard_config.yaml"
MCA_SUPPORTED_MODELS = {"deepseek_v3", "llama", "mistral", "mixtral", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "qwen3_next"}
METHODS = ["full", "freeze", "lora", "oft"] METHODS = ["full", "freeze", "lora", "oft"]
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"} MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

View File

@ -70,10 +70,6 @@ def is_matplotlib_available():
return _is_package_available("matplotlib")
-def is_mcore_adapter_available():
-return _is_package_available("mcore_adapter")
def is_pillow_available():
return _is_package_available("PIL")

View File

@ -461,7 +461,7 @@ class FinetuningArguments(
default="sft",
metadata={"help": "Which stage will be performed in training."},
)
-finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
+finetuning_type: Literal["lora", "freeze", "full"] = field(
default="lora",
metadata={"help": "Which fine-tuning method to use."},
)
@ -473,10 +473,6 @@ class FinetuningArguments(
default=False,
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
)
-use_mca: bool = field(
-default=False,
-metadata={"help": "Whether or not to use MCA (Megatron Core Adapter) training. Controlled by USE_MCA environment variable."},
-)
use_muon: bool = field(
default=False,
metadata={"help": "Whether or not to use the Muon optimizer."},

View File

@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
from ..extras import logging
from ..extras.constants import CHECKPOINT_NAMES, EngineName
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
-from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
+from ..extras.packages import is_transformers_version_greater_than
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
@ -53,13 +53,6 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
-if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
-from mcore_adapter import TrainingArguments as McaTrainingArguments
-_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
-_TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
-else:
-_TRAIN_MCA_ARGS = []
-_TRAIN_MCA_CLS = tuple()
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
r"""Get arguments from the command line or a config file."""
@ -204,27 +197,6 @@ def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
-def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
-parser = HfArgumentParser(_TRAIN_MCA_ARGS)
-allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
-model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
-parser, args, allow_extra_keys=allow_extra_keys
-)
-_configure_mca_training_args(training_args, data_args, finetuning_args)
-return model_args, data_args, training_args, finetuning_args, generating_args
-def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
-"""Patch training args to avoid args checking errors and sync MCA settings."""
-training_args.predict_with_generate = False
-training_args.generation_max_length = data_args.cutoff_len
-training_args.generation_num_beams = 1
-training_args.use_mca = True
-finetuning_args.use_mca = True
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
parser = HfArgumentParser(_INFER_ARGS)
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
@ -244,11 +216,7 @@ def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Ray
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
-if is_env_enabled("USE_MCA"):
-model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
-else:
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
-finetuning_args.use_mca = False
# Setup logging
if training_args.should_log:

View File

@ -19,20 +19,7 @@ from typing import Literal, Optional, Union
from transformers import Seq2SeqTrainingArguments
from transformers.training_args import _convert_str_dict
-from ..extras.misc import is_env_enabled, use_ray
+from ..extras.misc import use_ray
-if is_env_enabled("USE_MCA"):
-try:
-from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
-BaseTrainingArguments = McaSeq2SeqTrainingArguments
-except ImportError:
-raise ImportError(
-"mcore_adapter is required when USE_MCA=1.",
-"Please install `mcore_adapter` and its dependencies."
-)
-else:
-BaseTrainingArguments = Seq2SeqTrainingArguments
@dataclass
@ -91,7 +78,7 @@ class RayArguments:
@dataclass
-class TrainingArguments(RayArguments, BaseTrainingArguments):
+class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
r"""Arguments pertaining to the trainer."""
overwrite_output_dir: bool = field(
@ -100,5 +87,5 @@ class TrainingArguments(RayArguments, BaseTrainingArguments):
)
def __post_init__(self):
+Seq2SeqTrainingArguments.__post_init__(self)
RayArguments.__post_init__(self)
-BaseTrainingArguments.__post_init__(self)

View File

@ -54,10 +54,6 @@ def launch():
)
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
-if is_env_enabled("USE_MCA"):
-# force use torchrun
-os.environ["FORCE_TORCHRUN"] = "1"
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
# launch distributed training
nnodes = os.getenv("NNODES", "1")

View File

@ -1,19 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_dpo, run_pt, run_sft
__all__ = ["run_dpo", "run_pt", "run_sft"]

View File

@ -1,15 +0,0 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO override the original trainer

View File

@ -1,292 +0,0 @@
# Copyright 2025 the ROLL team and the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
from __future__ import annotations
import functools
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING, Any
from ...data import (
SFTDataCollatorWith4DAttentionMask,
get_dataset,
get_template_and_fix_tokenizer,
)
from ...data.collator import (
PairwiseDataCollatorWithPadding,
)
from ...extras.constants import IGNORE_INDEX, MCA_SUPPORTED_MODELS
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_mcore_adapter_available
from ...extras.ploting import plot_loss
from ...model import load_tokenizer
from ..callbacks import SaveProcessorCallback
if not is_mcore_adapter_available():
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
from mcore_adapter.models import AutoConfig, AutoModel
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
if TYPE_CHECKING:
from transformers import DataCollatorForSeq2Seq, TrainerCallback
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
logger = get_logger(__name__)
def _data_collator_wrapper(data_collator: Any):
@functools.wraps(data_collator)
def wrapper(features: Sequence[dict[str, Any]]):
labels_key = [k for k in features[0].keys() if k.endswith("labels")]
input_ids_key = [k for k in features[0].keys() if k.endswith("input_ids")]
for feature in features:
if len(labels_key) == 0: # pt
feature["labels"] = deepcopy(feature["input_ids"])[1:]
for k in labels_key:
feature[k] = feature[k][1:]
for k in input_ids_key:
feature[k] = feature[k][:-1]
for k in ["attention_mask", "position_ids"]:
if k in feature:
feature[k] = feature[k][:-1]
return data_collator(features)
return wrapper
def _check_model_support(model_args: ModelArguments):
from transformers import AutoConfig as HfAutoConfig
config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
if config.model_type not in MCA_SUPPORTED_MODELS:
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
def run_pt(
model_args: ModelArguments,
data_args: DataArguments,
training_args: McaSeq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: list[TrainerCallback] | None = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
# dataset needs +1 then cut back due to MCA shift logic
data_args.cutoff_len += 1
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
data_args.cutoff_len -= 1
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
from transformers import DataCollatorForSeq2Seq
data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
tokenizer=tokenizer,
pad_to_multiple_of=8,
label_pad_token_id=IGNORE_INDEX,
)
data_collator = _data_collator_wrapper(data_collator)
trainer = McaTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
)
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
if training_args.do_train:
train_result = trainer.train(training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
def run_sft(
model_args: ModelArguments,
data_args: DataArguments,
training_args: McaSeq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: list[TrainerCallback] | None = None,
):
# align packing flags
# TODO: FIX SequencePacking
data_args.neat_packing = training_args.sequence_packing = data_args.neat_packing or training_args.sequence_packing
data_args.packing = data_args.neat_packing or data_args.packing
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
# dataset needs +1 then cut back due to MCA shift logic
data_args.cutoff_len += 1
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
data_args.cutoff_len -= 1
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
# optional freezing for qwen2_vl, qwen2_5_vl
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
p.requires_grad_(False)
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in ["multi_modal_projector"]):
p.requires_grad_(False)
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
for name, p in model.named_parameters():
if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
p.requires_grad_(False)
pad_to_max = (
training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
)
data_collator = SFTDataCollatorWith4DAttentionMask(
template=template,
padding="max_length" if pad_to_max else "longest",
max_length=data_args.cutoff_len if pad_to_max else None,
pad_to_multiple_of=64,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
data_collator = _data_collator_wrapper(data_collator)
trainer = McaTrainer(
model=model,
args=training_args,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
)
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
train_result = trainer.train(training_args.resume_from_checkpoint)
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)
def run_dpo(
model_args: ModelArguments,
data_args: DataArguments,
training_args: McaSeq2SeqTrainingArguments,
finetuning_args: FinetuningArguments,
callbacks: list[TrainerCallback] | None = None,
):
tokenizer_module = load_tokenizer(model_args)
tokenizer = tokenizer_module["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)
_check_model_support(model_args)
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
if finetuning_args.use_ref_model:
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
ref_model = AutoModel.from_config(ref_config)
ref_model.load_state_dict(model.state_dict())
else:
ref_model = None
# dataset needs +1 then cut back due to MCA shift logic
data_args.cutoff_len += 1
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
data_args.cutoff_len -= 1
pad_to_max = (
training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
)
dpo_config = DPOConfig(
beta=finetuning_args.pref_beta,
pref_loss=finetuning_args.pref_loss,
label_smoothing=finetuning_args.dpo_label_smoothing,
)
data_collator = PairwiseDataCollatorWithPadding(
template=template,
pad_to_multiple_of=64,
padding="max_length" if pad_to_max else "longest",
max_length=data_args.cutoff_len if pad_to_max else None,
label_pad_token_id=IGNORE_INDEX,
**tokenizer_module,
)
data_collator = _data_collator_wrapper(data_collator)
trainer = McaDPOTrainer(
model=model,
ref_model=ref_model,
args=training_args,
train_config=dpo_config,
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**dataset_module,
)
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
train_result = trainer.train(training_args.resume_from_checkpoint)
trainer.save_model()
if finetuning_args.include_effective_tokens_per_second:
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
dataset_module["train_dataset"], train_result.metrics, stage="rm"
)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
keys = ["loss", "rewards/accuracies"]
if isinstance(dataset_module.get("eval_dataset"), dict):
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
else:
keys += ["eval_loss"]
plot_loss(training_args.output_dir, keys=keys)

View File

@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.misc import infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available
+from ..extras.packages import is_ray_available
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
from ..model import load_model, load_tokenizer
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@ -66,19 +66,7 @@ def _training_function(config: dict[str, Any]) -> None:
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
-if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
-if not is_mcore_adapter_available():
-raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
if finetuning_args.stage == "pt":
-from .mca import run_pt as run_pt_mca
-run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
-elif finetuning_args.stage == "sft":
-from .mca import run_sft as run_sft_mca
-run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
-else: # dpo
-from .mca import run_dpo as run_dpo_mca
-run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
-elif finetuning_args.stage == "pt":
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
elif finetuning_args.stage == "sft":
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)