mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-05 10:22:15 +08:00
Compare commits
2 Commits
9c0d033a15
...
13170577b2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13170577b2 | ||
|
|
129e918106 |
77
docker/docker-cuda/Dockerfile.megatron
Normal file
77
docker/docker-cuda/Dockerfile.megatron
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
# NVIDIA official image (ubuntu-22.04 + cuda-12.4 + python-3.10)
|
||||||
|
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-08.html
|
||||||
|
FROM nvcr.io/nvidia/pytorch:24.05-py3
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PIP_ROOT_USER_ACTION=ignore
|
||||||
|
ENV PYPI_MIRROR=https://mirrors.aliyun.com/pypi/simple/
|
||||||
|
ENV PYPI_TRUSTED_HOST=mirrors.aliyun.com
|
||||||
|
ENV APT_MIRROR=https://mirrors.tuna.tsinghua.edu.cn/ubuntu/
|
||||||
|
|
||||||
|
RUN pip install --upgrade pip setuptools wheel --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
|
||||||
|
|
||||||
|
RUN pip uninstall -y torch torchvision torch-tensorrt \
|
||||||
|
flash_attn transformer-engine \
|
||||||
|
cudf dask-cuda cugraph cugraph-service-server cuml raft-dask cugraph-dgl cugraph-pyg dask-cudf
|
||||||
|
|
||||||
|
RUN pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
|
||||||
|
|
||||||
|
RUN pip uninstall -y opencv opencv-python opencv-python-headless && \
|
||||||
|
rm -rf /usr/local/lib/python3.10/dist-packages/cv2/ && \
|
||||||
|
pip install opencv-python-headless==4.11.0.86 --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
|
||||||
|
|
||||||
|
RUN pip install "numpy==1.26.4" "optree>=0.13.0" "spacy==3.7.5" "weasel==0.4.1" \
|
||||||
|
transformer-engine[pytorch]==2.2.0 megatron-core==0.13.0 deepspeed==0.16.4 \
|
||||||
|
--trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
|
||||||
|
|
||||||
|
RUN pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.2.post1/flash_attn-2.7.2.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
|
# RUN pip install vllm==0.8.4 \
|
||||||
|
# --trusted-host ${PYPI_TRUSTED_HOST} --index-url ${PYPI_MIRROR}
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
ARG apex_url=git+https://github.com/NVIDIA/apex.git@25.04
|
||||||
|
RUN pip uninstall -y apex && \
|
||||||
|
MAX_JOBS=32 NINJA_FLAGS="-j32" NVCC_APPEND_FLAGS="--threads 32" \
|
||||||
|
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
|
||||||
|
--config-settings "--build-option=--cpp_ext --cuda_ext --parallel 32" ${apex_url}
|
||||||
|
|
||||||
|
RUN rm -rf /build
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
RUN cp /etc/apt/sources.list /etc/apt/sources.list.bak && \
|
||||||
|
{ \
|
||||||
|
echo "deb ${APT_MIRROR} jammy main restricted universe multiverse"; \
|
||||||
|
echo "deb ${APT_MIRROR} jammy-security main restricted universe multiverse"; \
|
||||||
|
echo "deb ${APT_MIRROR} jammy-updates main restricted universe multiverse"; \
|
||||||
|
echo "deb ${APT_MIRROR} jammy-backports main restricted universe multiverse"; \
|
||||||
|
} > /etc/apt/sources.list
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y zip
|
||||||
|
|
||||||
|
RUN apt-get install -y openjdk-21-jdk
|
||||||
|
ENV JAVA_HOME /usr/lib/jvm/java-21-openjdk-amd64
|
||||||
|
|
||||||
|
# pip install LLaMA-Factory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY requirements.txt /app/
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
RUN pip install "git+https://github.com/alibaba/roll.git#subdirectory=mcore_adapter"
|
||||||
|
|
||||||
|
COPY . /app/
|
||||||
|
RUN pip install -e ".[metrics]" --no-build-isolation
|
||||||
|
|
||||||
|
# Expose port 7860 for LLaMA Board
|
||||||
|
ENV GRADIO_SERVER_PORT=7860
|
||||||
|
EXPOSE 7860
|
||||||
|
|
||||||
|
# Expose port 8000 for API service
|
||||||
|
ENV API_PORT=8000
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# unset proxy
|
||||||
|
ENV http_proxy=
|
||||||
|
ENV https_proxy=
|
||||||
29
examples/megatron/qwen2_vl_full.yaml
Normal file
29
examples/megatron/qwen2_vl_full.yaml
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
|
||||||
|
image_max_pixels: 262144
|
||||||
|
video_max_pixels: 16384
|
||||||
|
|
||||||
|
do_train: true
|
||||||
|
stage: sft
|
||||||
|
finetuning_type: full # only support full for now
|
||||||
|
dataset: llava_1k_en
|
||||||
|
preprocessing_num_workers: 8
|
||||||
|
cutoff_len: 4096
|
||||||
|
template: qwen2_vl
|
||||||
|
|
||||||
|
output_dir: saves/mca/qwen2_vl_full
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
num_train_epochs: 2
|
||||||
|
learning_rate: 2e-5
|
||||||
|
logging_steps: 1
|
||||||
|
save_steps: 100
|
||||||
|
lr_scheduler_type: cosine
|
||||||
|
bf16: true
|
||||||
|
|
||||||
|
# mcore speed up
|
||||||
|
tensor_model_parallel_size: 4
|
||||||
|
sequence_parallel: true
|
||||||
|
pipeline_model_parallel_size: 2
|
||||||
|
bias_activation_fusion: true
|
||||||
|
apply_rope_fusion: true
|
||||||
|
use_distributed_optimizer: true
|
||||||
35
examples/megatron/qwen3_moe_full.yaml
Normal file
35
examples/megatron/qwen3_moe_full.yaml
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
model_name_or_path: Qwen/Qwen3-30B-A3B-Instruct-2507
|
||||||
|
|
||||||
|
# GPU memory: 8 * 78GB
|
||||||
|
do_train: true
|
||||||
|
stage: sft
|
||||||
|
finetuning_type: full # only support full for now
|
||||||
|
dataset: alpaca_en_demo
|
||||||
|
preprocessing_num_workers: 8
|
||||||
|
cutoff_len: 4096
|
||||||
|
template: qwen3_nothink
|
||||||
|
|
||||||
|
# global batchsize = (8 // 2 // 4) * 8 = 8
|
||||||
|
output_dir: saves/mca/qwen3_moe_full
|
||||||
|
per_device_train_batch_size: 1
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
num_train_epochs: 2
|
||||||
|
learning_rate: 3e-6
|
||||||
|
logging_steps: 1
|
||||||
|
save_steps: 100
|
||||||
|
lr_scheduler_type: constant
|
||||||
|
bf16: true
|
||||||
|
|
||||||
|
# mcore speed up
|
||||||
|
tensor_model_parallel_size: 1
|
||||||
|
sequence_parallel: false
|
||||||
|
pipeline_model_parallel_size: 4
|
||||||
|
bias_activation_fusion: true
|
||||||
|
apply_rope_fusion: true
|
||||||
|
use_distributed_optimizer: true
|
||||||
|
overlap_param_gather: true
|
||||||
|
overlap_grad_reduce: true
|
||||||
|
moe_grouped_gemm: true
|
||||||
|
moe_token_dispatcher_type: alltoall
|
||||||
|
expert_model_parallel_size: 2
|
||||||
|
recompute_granularity: full
|
||||||
125
scripts/megatron_merge.py
Normal file
125
scripts/megatron_merge.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
# Copyright 2025 the ROLL team and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# This code is modified from the ROLL library.
|
||||||
|
# https://github.com/alibaba/ROLL/blob/main/mcore_adapter/tools/convert.py
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import fire
|
||||||
|
import torch
|
||||||
|
from mcore_adapter.models.converter.post_converter import convert_checkpoint_to_hf, convert_checkpoint_to_mca
|
||||||
|
from mcore_adapter.training_args import DistributingParallelArguments
|
||||||
|
from mcore_adapter.utils import get_logger
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_mca_to_hf(
|
||||||
|
checkpoint_path: str,
|
||||||
|
output_path: str = "./output",
|
||||||
|
bf16: bool = False,
|
||||||
|
fp16: bool = False,
|
||||||
|
convert_model_max_length: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Convert megatron checkpoint to HuggingFace format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: Path to the checkpoint to convert
|
||||||
|
output_path: Path to save the converted checkpoint
|
||||||
|
bf16: Use bfloat16 precision
|
||||||
|
fp16: Use float16 precision
|
||||||
|
convert_model_max_length: Change the model_max_length in hf config.json
|
||||||
|
"""
|
||||||
|
if bf16 and fp16:
|
||||||
|
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||||
|
|
||||||
|
torch_dtype = None
|
||||||
|
if bf16:
|
||||||
|
torch_dtype = torch.bfloat16
|
||||||
|
elif fp16:
|
||||||
|
torch_dtype = torch.float16
|
||||||
|
|
||||||
|
convert_checkpoint_to_hf(checkpoint_path, output_path, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
if convert_model_max_length is not None:
|
||||||
|
config = AutoConfig.from_pretrained(output_path, trust_remote_code=True)
|
||||||
|
config.model_max_length = convert_model_max_length
|
||||||
|
config.save_pretrained(output_path)
|
||||||
|
|
||||||
|
|
||||||
|
def convert(
|
||||||
|
checkpoint_path: str,
|
||||||
|
output_path: str = "./output",
|
||||||
|
bf16: bool = False,
|
||||||
|
fp16: bool = False,
|
||||||
|
convert_model_max_length: Optional[int] = None,
|
||||||
|
tensor_model_parallel_size: int = 1,
|
||||||
|
pipeline_model_parallel_size: int = 1,
|
||||||
|
expert_model_parallel_size: int = 1,
|
||||||
|
virtual_pipeline_model_parallel_size: Optional[int] = None,
|
||||||
|
):
|
||||||
|
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint_path: Path to the checkpoint to convert
|
||||||
|
output_path: Path to save the converted checkpoint
|
||||||
|
bf16: Use bfloat16 precision
|
||||||
|
fp16: Use float16 precision
|
||||||
|
convert_model_max_length: Change the model_max_length in hf config.json
|
||||||
|
tensor_model_parallel_size: Tensor model parallel size
|
||||||
|
pipeline_model_parallel_size: Pipeline model parallel size
|
||||||
|
expert_model_parallel_size: Expert model parallel size
|
||||||
|
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
||||||
|
"""
|
||||||
|
if bf16 and fp16:
|
||||||
|
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||||
|
|
||||||
|
mca_config_path = os.path.join(checkpoint_path, "mca_config.json")
|
||||||
|
from_mca = os.path.exists(mca_config_path)
|
||||||
|
|
||||||
|
if not from_mca:
|
||||||
|
dist_args = DistributingParallelArguments(
|
||||||
|
tensor_model_parallel_size=tensor_model_parallel_size,
|
||||||
|
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
||||||
|
expert_model_parallel_size=expert_model_parallel_size,
|
||||||
|
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
convert_checkpoint_to_mca(
|
||||||
|
checkpoint_path,
|
||||||
|
output_path,
|
||||||
|
dist_args,
|
||||||
|
bf16=bf16,
|
||||||
|
fp16=fp16,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
convert_mca_to_hf(
|
||||||
|
checkpoint_path=checkpoint_path,
|
||||||
|
output_path=output_path,
|
||||||
|
bf16=bf16,
|
||||||
|
fp16=fp16,
|
||||||
|
convert_model_max_length=convert_model_max_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
fire.Fire(convert)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -16,6 +16,7 @@ import gc
|
|||||||
import json
|
import json
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import av
|
||||||
import fire
|
import fire
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
@ -33,6 +34,14 @@ if is_vllm_available():
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
|
|
||||||
|
|
||||||
|
def _need_video_kwargs(template):
|
||||||
|
NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
|
||||||
|
if any(t in template for t in NEEDED_TEMPLATE):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def vllm_infer(
|
def vllm_infer(
|
||||||
model_name_or_path: str,
|
model_name_or_path: str,
|
||||||
adapter_name_or_path: str = None,
|
adapter_name_or_path: str = None,
|
||||||
@ -132,6 +141,7 @@ def vllm_infer(
|
|||||||
|
|
||||||
# Store all results in these lists
|
# Store all results in these lists
|
||||||
all_prompts, all_preds, all_labels = [], [], []
|
all_prompts, all_preds, all_labels = [], [], []
|
||||||
|
need_video_kwargs = _need_video_kwargs(template)
|
||||||
|
|
||||||
# Add batch process to avoid the issue of too many files opened
|
# Add batch process to avoid the issue of too many files opened
|
||||||
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
|
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
|
||||||
@ -147,6 +157,7 @@ def vllm_infer(
|
|||||||
)["images"]
|
)["images"]
|
||||||
}
|
}
|
||||||
elif batch["videos"][j] is not None:
|
elif batch["videos"][j] is not None:
|
||||||
|
video_metadata, video_metadata_kwargs = None, None
|
||||||
video = batch["videos"][j]
|
video = batch["videos"][j]
|
||||||
multi_modal_data = {
|
multi_modal_data = {
|
||||||
"video": template_obj.mm_plugin._regularize_videos(
|
"video": template_obj.mm_plugin._regularize_videos(
|
||||||
@ -157,6 +168,25 @@ def vllm_infer(
|
|||||||
video_maxlen=video_maxlen,
|
video_maxlen=video_maxlen,
|
||||||
)["videos"]
|
)["videos"]
|
||||||
}
|
}
|
||||||
|
if need_video_kwargs:
|
||||||
|
container = av.open(video[0], "r")
|
||||||
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
|
sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
|
||||||
|
video_stream, video_fps, video_maxlen
|
||||||
|
)
|
||||||
|
total_frames = video_stream.frames
|
||||||
|
video_metadata_kwargs = {
|
||||||
|
"fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
|
||||||
|
"do_sample_frames": False,
|
||||||
|
"total_num_frames": total_frames,
|
||||||
|
}
|
||||||
|
video_metadata = dict(
|
||||||
|
fps=video_fps,
|
||||||
|
frames_indices=sampling_indices,
|
||||||
|
total_num_frames=total_frames,
|
||||||
|
video_backend="opencv",
|
||||||
|
)
|
||||||
|
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||||
elif batch["audios"][j] is not None:
|
elif batch["audios"][j] is not None:
|
||||||
audio = batch["audios"][j]
|
audio = batch["audios"][j]
|
||||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||||
@ -167,7 +197,11 @@ def vllm_infer(
|
|||||||
else:
|
else:
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
|
|
||||||
vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
|
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
||||||
|
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
||||||
|
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||||
|
|
||||||
|
vllm_inputs.append(vllm_input_data)
|
||||||
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
|
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
|
||||||
labels.append(
|
labels.append(
|
||||||
tokenizer.decode(
|
tokenizer.decode(
|
||||||
|
|||||||
@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
|
|||||||
convert_sparse_cross_attention_mask_to_dense,
|
convert_sparse_cross_attention_mask_to_dense,
|
||||||
get_cross_attention_token_mask,
|
get_cross_attention_token_mask,
|
||||||
)
|
)
|
||||||
from typing_extensions import override
|
from typing_extensions import NotRequired, override
|
||||||
|
|
||||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||||
from ..extras.packages import (
|
from ..extras.packages import (
|
||||||
@ -77,6 +77,18 @@ if TYPE_CHECKING:
|
|||||||
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
||||||
AudioInput = Union[str, BinaryIO, NDArray]
|
AudioInput = Union[str, BinaryIO, NDArray]
|
||||||
|
|
||||||
|
class RegularizedImageOutput(TypedDict):
|
||||||
|
images: list[ImageObject]
|
||||||
|
|
||||||
|
class RegularizedVideoOutput(TypedDict):
|
||||||
|
videos: list[list[ImageObject]]
|
||||||
|
durations: list[float]
|
||||||
|
fps_per_video: NotRequired[list[float]]
|
||||||
|
|
||||||
|
class RegularizedAudioOutput(TypedDict):
|
||||||
|
audios: list[NDArray]
|
||||||
|
sampling_rates: list[float]
|
||||||
|
|
||||||
class MMProcessor(ProcessorMixin):
|
class MMProcessor(ProcessorMixin):
|
||||||
patch_size: int
|
patch_size: int
|
||||||
image_seq_length: int
|
image_seq_length: int
|
||||||
@ -244,7 +256,7 @@ class MMPluginMixin:
|
|||||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||||
|
|
||||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
|
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
|
||||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||||
results = []
|
results = []
|
||||||
for image in images:
|
for image in images:
|
||||||
@ -265,9 +277,10 @@ class MMPluginMixin:
|
|||||||
|
|
||||||
return {"images": results}
|
return {"images": results}
|
||||||
|
|
||||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
|
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||||
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
||||||
results = []
|
results = []
|
||||||
|
durations = []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
frames: list[ImageObject] = []
|
frames: list[ImageObject] = []
|
||||||
if _check_video_is_nested_images(video):
|
if _check_video_is_nested_images(video):
|
||||||
@ -275,6 +288,7 @@ class MMPluginMixin:
|
|||||||
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||||
raise ValueError("Invalid image found in video frames.")
|
raise ValueError("Invalid image found in video frames.")
|
||||||
frames = video
|
frames = video
|
||||||
|
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||||
else:
|
else:
|
||||||
container = av.open(video, "r")
|
container = av.open(video, "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
@ -284,14 +298,19 @@ class MMPluginMixin:
|
|||||||
if frame_idx in sample_indices:
|
if frame_idx in sample_indices:
|
||||||
frames.append(frame.to_image())
|
frames.append(frame.to_image())
|
||||||
|
|
||||||
|
if video_stream.duration is None:
|
||||||
|
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||||
|
else:
|
||||||
|
durations.append(float(video_stream.duration * video_stream.time_base))
|
||||||
|
|
||||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||||
results.append(frames)
|
results.append(frames)
|
||||||
|
|
||||||
return {"videos": results}
|
return {"videos": results, "durations": durations}
|
||||||
|
|
||||||
def _regularize_audios(
|
def _regularize_audios(
|
||||||
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
|
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
|
||||||
) -> dict[str, Union[list["NDArray"], list[float]]]:
|
) -> "RegularizedAudioOutput":
|
||||||
r"""Regularizes audios to avoid error. Including reading and resampling."""
|
r"""Regularizes audios to avoid error. Including reading and resampling."""
|
||||||
results, sampling_rates = [], []
|
results, sampling_rates = [], []
|
||||||
for audio in audios:
|
for audio in audios:
|
||||||
@ -1418,10 +1437,8 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _regularize_videos(
|
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||||
self, videos: list["VideoInput"], **kwargs
|
results, fps_per_video, durations = [], [], []
|
||||||
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
|
|
||||||
results, fps_per_video = [], []
|
|
||||||
for video in videos:
|
for video in videos:
|
||||||
frames: list[ImageObject] = []
|
frames: list[ImageObject] = []
|
||||||
if _check_video_is_nested_images(video):
|
if _check_video_is_nested_images(video):
|
||||||
@ -1431,6 +1448,7 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
|
|
||||||
frames = video
|
frames = video
|
||||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||||
|
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||||
else:
|
else:
|
||||||
container = av.open(video, "r")
|
container = av.open(video, "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
@ -1442,8 +1460,10 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
|
|
||||||
if video_stream.duration is None:
|
if video_stream.duration is None:
|
||||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||||
|
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||||
else:
|
else:
|
||||||
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
||||||
|
durations.append(float(video_stream.duration * video_stream.time_base))
|
||||||
|
|
||||||
if len(frames) % 2 != 0:
|
if len(frames) % 2 != 0:
|
||||||
frames.append(frames[-1])
|
frames.append(frames[-1])
|
||||||
@ -1451,7 +1471,7 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||||
results.append(frames)
|
results.append(frames)
|
||||||
|
|
||||||
return {"videos": results, "fps_per_video": fps_per_video}
|
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_mm_inputs(
|
def _get_mm_inputs(
|
||||||
@ -1565,8 +1585,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
)
|
)
|
||||||
video_metadata = [
|
video_metadata = [
|
||||||
{"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
|
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
|
||||||
for video in videos["videos"]
|
for video, duration in zip(videos["videos"], videos["durations"])
|
||||||
]
|
]
|
||||||
mm_inputs.update(
|
mm_inputs.update(
|
||||||
video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
|
video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
|
||||||
@ -1622,6 +1642,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
num_image_tokens += 1
|
num_image_tokens += 1
|
||||||
|
|
||||||
while VIDEO_PLACEHOLDER in content:
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
if self.expand_mm_tokens:
|
||||||
metadata = video_metadata[idx]
|
metadata = video_metadata[idx]
|
||||||
timestamps = processor._calculate_timestamps(
|
timestamps = processor._calculate_timestamps(
|
||||||
metadata.frames_indices,
|
metadata.frames_indices,
|
||||||
@ -1641,8 +1662,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
|
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
|
||||||
)
|
)
|
||||||
video_structure += frame_structure
|
video_structure += frame_structure
|
||||||
|
else:
|
||||||
if not self.expand_mm_tokens:
|
|
||||||
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
|
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
|
||||||
|
|
||||||
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
|
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
|
||||||
@ -1684,7 +1704,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
|
|||||||
)
|
)
|
||||||
# prepare video metadata
|
# prepare video metadata
|
||||||
video_metadata = [
|
video_metadata = [
|
||||||
{"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
|
{"fps": 2, "duration": duration, "total_frames": len(video)}
|
||||||
|
for video, duration in zip(video_data["videos"], video_data["durations"])
|
||||||
]
|
]
|
||||||
mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
|
mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
|
||||||
|
|
||||||
|
|||||||
@ -56,6 +56,8 @@ LAYERNORM_NAMES = {"norm", "ln"}
|
|||||||
|
|
||||||
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
||||||
|
|
||||||
|
MCA_SUPPORTED_MODELS = {"deepseek_v3", "llama", "mistral", "mixtral", "qwen2", "qwen2_vl", "qwen2_5_vl", "qwen3", "qwen3_moe", "qwen3_next"}
|
||||||
|
|
||||||
METHODS = ["full", "freeze", "lora", "oft"]
|
METHODS = ["full", "freeze", "lora", "oft"]
|
||||||
|
|
||||||
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
||||||
|
|||||||
@ -70,6 +70,10 @@ def is_matplotlib_available():
|
|||||||
return _is_package_available("matplotlib")
|
return _is_package_available("matplotlib")
|
||||||
|
|
||||||
|
|
||||||
|
def is_mcore_adapter_available():
|
||||||
|
return _is_package_available("mcore_adapter")
|
||||||
|
|
||||||
|
|
||||||
def is_pillow_available():
|
def is_pillow_available():
|
||||||
return _is_package_available("PIL")
|
return _is_package_available("PIL")
|
||||||
|
|
||||||
|
|||||||
@ -461,7 +461,7 @@ class FinetuningArguments(
|
|||||||
default="sft",
|
default="sft",
|
||||||
metadata={"help": "Which stage will be performed in training."},
|
metadata={"help": "Which stage will be performed in training."},
|
||||||
)
|
)
|
||||||
finetuning_type: Literal["lora", "freeze", "full"] = field(
|
finetuning_type: Literal["lora", "oft", "freeze", "full"] = field(
|
||||||
default="lora",
|
default="lora",
|
||||||
metadata={"help": "Which fine-tuning method to use."},
|
metadata={"help": "Which fine-tuning method to use."},
|
||||||
)
|
)
|
||||||
@ -473,6 +473,10 @@ class FinetuningArguments(
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
|
metadata={"help": "Whether or not to use the Adam-mini optimizer."},
|
||||||
)
|
)
|
||||||
|
use_mca: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Whether or not to use MCA (Megatron Core Adapter) training. Controlled by USE_MCA environment variable."},
|
||||||
|
)
|
||||||
use_muon: bool = field(
|
use_muon: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether or not to use the Muon optimizer."},
|
metadata={"help": "Whether or not to use the Muon optimizer."},
|
||||||
|
|||||||
@ -32,7 +32,7 @@ from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_availab
|
|||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import CHECKPOINT_NAMES, EngineName
|
from ..extras.constants import CHECKPOINT_NAMES, EngineName
|
||||||
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
|
from ..extras.misc import check_dependencies, check_version, get_current_device, is_env_enabled
|
||||||
from ..extras.packages import is_transformers_version_greater_than
|
from ..extras.packages import is_mcore_adapter_available, is_transformers_version_greater_than
|
||||||
from .data_args import DataArguments
|
from .data_args import DataArguments
|
||||||
from .evaluation_args import EvaluationArguments
|
from .evaluation_args import EvaluationArguments
|
||||||
from .finetuning_args import FinetuningArguments
|
from .finetuning_args import FinetuningArguments
|
||||||
@ -53,6 +53,13 @@ _INFER_CLS = tuple[ModelArguments, DataArguments, FinetuningArguments, Generatin
|
|||||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||||
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
_EVAL_CLS = tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||||
|
|
||||||
|
if is_mcore_adapter_available() and is_env_enabled("USE_MCA"):
|
||||||
|
from mcore_adapter import TrainingArguments as McaTrainingArguments
|
||||||
|
_TRAIN_MCA_ARGS = [ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
|
_TRAIN_MCA_CLS = tuple[ModelArguments, DataArguments, McaTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||||
|
else:
|
||||||
|
_TRAIN_MCA_ARGS = []
|
||||||
|
_TRAIN_MCA_CLS = tuple()
|
||||||
|
|
||||||
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
|
||||||
r"""Get arguments from the command line or a config file."""
|
r"""Get arguments from the command line or a config file."""
|
||||||
@ -197,6 +204,27 @@ def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -
|
|||||||
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
|
||||||
|
parser = HfArgumentParser(_TRAIN_MCA_ARGS)
|
||||||
|
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||||
|
model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
|
||||||
|
parser, args, allow_extra_keys=allow_extra_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
_configure_mca_training_args(training_args, data_args, finetuning_args)
|
||||||
|
|
||||||
|
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||||
|
|
||||||
|
|
||||||
|
def _configure_mca_training_args(training_args, data_args, finetuning_args) -> None:
|
||||||
|
"""Patch training args to avoid args checking errors and sync MCA settings."""
|
||||||
|
training_args.predict_with_generate = False
|
||||||
|
training_args.generation_max_length = data_args.cutoff_len
|
||||||
|
training_args.generation_num_beams = 1
|
||||||
|
training_args.use_mca = True
|
||||||
|
finetuning_args.use_mca = True
|
||||||
|
|
||||||
|
|
||||||
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
|
||||||
parser = HfArgumentParser(_INFER_ARGS)
|
parser = HfArgumentParser(_INFER_ARGS)
|
||||||
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
|
||||||
@ -216,7 +244,11 @@ def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Ray
|
|||||||
|
|
||||||
|
|
||||||
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
|
||||||
|
if is_env_enabled("USE_MCA"):
|
||||||
|
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
|
||||||
|
else:
|
||||||
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
|
||||||
|
finetuning_args.use_mca = False
|
||||||
|
|
||||||
# Setup logging
|
# Setup logging
|
||||||
if training_args.should_log:
|
if training_args.should_log:
|
||||||
|
|||||||
@ -19,7 +19,20 @@ from typing import Literal, Optional, Union
|
|||||||
from transformers import Seq2SeqTrainingArguments
|
from transformers import Seq2SeqTrainingArguments
|
||||||
from transformers.training_args import _convert_str_dict
|
from transformers.training_args import _convert_str_dict
|
||||||
|
|
||||||
from ..extras.misc import use_ray
|
from ..extras.misc import is_env_enabled, use_ray
|
||||||
|
|
||||||
|
|
||||||
|
if is_env_enabled("USE_MCA"):
|
||||||
|
try:
|
||||||
|
from mcore_adapter import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||||
|
BaseTrainingArguments = McaSeq2SeqTrainingArguments
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"mcore_adapter is required when USE_MCA=1.",
|
||||||
|
"Please install `mcore_adapter` and its dependencies."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
BaseTrainingArguments = Seq2SeqTrainingArguments
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -78,7 +91,7 @@ class RayArguments:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
|
class TrainingArguments(RayArguments, BaseTrainingArguments):
|
||||||
r"""Arguments pertaining to the trainer."""
|
r"""Arguments pertaining to the trainer."""
|
||||||
|
|
||||||
overwrite_output_dir: bool = field(
|
overwrite_output_dir: bool = field(
|
||||||
@ -87,5 +100,5 @@ class TrainingArguments(RayArguments, Seq2SeqTrainingArguments):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
Seq2SeqTrainingArguments.__post_init__(self)
|
|
||||||
RayArguments.__post_init__(self)
|
RayArguments.__post_init__(self)
|
||||||
|
BaseTrainingArguments.__post_init__(self)
|
||||||
|
|||||||
@ -54,6 +54,10 @@ def launch():
|
|||||||
)
|
)
|
||||||
|
|
||||||
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
command = sys.argv.pop(1) if len(sys.argv) > 1 else "help"
|
||||||
|
if is_env_enabled("USE_MCA"):
|
||||||
|
# force use torchrun
|
||||||
|
os.environ["FORCE_TORCHRUN"] = "1"
|
||||||
|
|
||||||
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
|
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
|
||||||
# launch distributed training
|
# launch distributed training
|
||||||
nnodes = os.getenv("NNODES", "1")
|
nnodes = os.getenv("NNODES", "1")
|
||||||
|
|||||||
19
src/llamafactory/train/mca/__init__.py
Normal file
19
src/llamafactory/train/mca/__init__.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .workflow import run_dpo, run_pt, run_sft
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["run_dpo", "run_pt", "run_sft"]
|
||||||
|
|
||||||
15
src/llamafactory/train/mca/trainer.py
Normal file
15
src/llamafactory/train/mca/trainer.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# TODO override the original trainer
|
||||||
292
src/llamafactory/train/mca/workflow.py
Normal file
292
src/llamafactory/train/mca/workflow.py
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
# Copyright 2025 the ROLL team and the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""MCA (mcore_adapter) workflows for PT/SFT/DPO stages, aligned with LLaMA-Factory's workflow style."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from ...data import (
|
||||||
|
SFTDataCollatorWith4DAttentionMask,
|
||||||
|
get_dataset,
|
||||||
|
get_template_and_fix_tokenizer,
|
||||||
|
)
|
||||||
|
from ...data.collator import (
|
||||||
|
PairwiseDataCollatorWithPadding,
|
||||||
|
)
|
||||||
|
from ...extras.constants import IGNORE_INDEX, MCA_SUPPORTED_MODELS
|
||||||
|
from ...extras.logging import get_logger
|
||||||
|
from ...extras.misc import calculate_tps
|
||||||
|
from ...extras.packages import is_mcore_adapter_available
|
||||||
|
from ...extras.ploting import plot_loss
|
||||||
|
from ...model import load_tokenizer
|
||||||
|
from ..callbacks import SaveProcessorCallback
|
||||||
|
|
||||||
|
|
||||||
|
if not is_mcore_adapter_available():
|
||||||
|
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
|
||||||
|
|
||||||
|
from mcore_adapter.models import AutoConfig, AutoModel
|
||||||
|
from mcore_adapter.trainer import DPOTrainer as McaDPOTrainer
|
||||||
|
from mcore_adapter.trainer import McaTrainer
|
||||||
|
from mcore_adapter.trainer.dpo_config import DPOConfig
|
||||||
|
from mcore_adapter.training_args import Seq2SeqTrainingArguments as McaSeq2SeqTrainingArguments
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from transformers import DataCollatorForSeq2Seq, TrainerCallback
|
||||||
|
|
||||||
|
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _data_collator_wrapper(data_collator: Any):
|
||||||
|
@functools.wraps(data_collator)
|
||||||
|
def wrapper(features: Sequence[dict[str, Any]]):
|
||||||
|
labels_key = [k for k in features[0].keys() if k.endswith("labels")]
|
||||||
|
input_ids_key = [k for k in features[0].keys() if k.endswith("input_ids")]
|
||||||
|
for feature in features:
|
||||||
|
if len(labels_key) == 0: # pt
|
||||||
|
feature["labels"] = deepcopy(feature["input_ids"])[1:]
|
||||||
|
for k in labels_key:
|
||||||
|
feature[k] = feature[k][1:]
|
||||||
|
for k in input_ids_key:
|
||||||
|
feature[k] = feature[k][:-1]
|
||||||
|
for k in ["attention_mask", "position_ids"]:
|
||||||
|
if k in feature:
|
||||||
|
feature[k] = feature[k][:-1]
|
||||||
|
return data_collator(features)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def _check_model_support(model_args: ModelArguments):
|
||||||
|
from transformers import AutoConfig as HfAutoConfig
|
||||||
|
config = HfAutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code)
|
||||||
|
if config.model_type not in MCA_SUPPORTED_MODELS:
|
||||||
|
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
|
||||||
|
|
||||||
|
def run_pt(
|
||||||
|
model_args: ModelArguments,
|
||||||
|
data_args: DataArguments,
|
||||||
|
training_args: McaSeq2SeqTrainingArguments,
|
||||||
|
finetuning_args: FinetuningArguments,
|
||||||
|
callbacks: list[TrainerCallback] | None = None,
|
||||||
|
):
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
|
|
||||||
|
# dataset needs +1 then cut back due to MCA shift logic
|
||||||
|
data_args.cutoff_len += 1
|
||||||
|
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="pt", **tokenizer_module)
|
||||||
|
data_args.cutoff_len -= 1
|
||||||
|
|
||||||
|
_check_model_support(model_args)
|
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||||
|
|
||||||
|
from transformers import DataCollatorForSeq2Seq
|
||||||
|
|
||||||
|
data_collator: DataCollatorForSeq2Seq = DataCollatorForSeq2Seq(
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
pad_to_multiple_of=8,
|
||||||
|
label_pad_token_id=IGNORE_INDEX,
|
||||||
|
)
|
||||||
|
data_collator = _data_collator_wrapper(data_collator)
|
||||||
|
|
||||||
|
trainer = McaTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
callbacks=callbacks,
|
||||||
|
**dataset_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
|
||||||
|
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
|
||||||
|
|
||||||
|
if training_args.do_train:
|
||||||
|
train_result = trainer.train(training_args.resume_from_checkpoint)
|
||||||
|
trainer.save_model()
|
||||||
|
trainer.log_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
|
keys = ["loss"]
|
||||||
|
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||||
|
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
|
||||||
|
else:
|
||||||
|
keys += ["eval_loss"]
|
||||||
|
plot_loss(training_args.output_dir, keys=keys)
|
||||||
|
|
||||||
|
|
||||||
|
def run_sft(
|
||||||
|
model_args: ModelArguments,
|
||||||
|
data_args: DataArguments,
|
||||||
|
training_args: McaSeq2SeqTrainingArguments,
|
||||||
|
finetuning_args: FinetuningArguments,
|
||||||
|
callbacks: list[TrainerCallback] | None = None,
|
||||||
|
):
|
||||||
|
# align packing flags
|
||||||
|
# TODO: FIX SequencePacking
|
||||||
|
data_args.neat_packing = training_args.sequence_packing = data_args.neat_packing or training_args.sequence_packing
|
||||||
|
data_args.packing = data_args.neat_packing or data_args.packing
|
||||||
|
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
|
|
||||||
|
# dataset needs +1 then cut back due to MCA shift logic
|
||||||
|
data_args.cutoff_len += 1
|
||||||
|
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
|
||||||
|
data_args.cutoff_len -= 1
|
||||||
|
|
||||||
|
_check_model_support(model_args)
|
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||||
|
|
||||||
|
# optional freezing for qwen2_vl, qwen2_5_vl
|
||||||
|
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_vision_tower:
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
if any(name.startswith(k) for k in ["vision_model.blocks", "vision_model.patch_embed"]):
|
||||||
|
p.requires_grad_(False)
|
||||||
|
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_multi_modal_projector:
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
if any(name.startswith(k) for k in ["multi_modal_projector"]):
|
||||||
|
p.requires_grad_(False)
|
||||||
|
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl"] and finetuning_args.freeze_language_model:
|
||||||
|
for name, p in model.named_parameters():
|
||||||
|
if any(name.startswith(k) for k in ["embedding", "decoder", "output_layer"]):
|
||||||
|
p.requires_grad_(False)
|
||||||
|
|
||||||
|
pad_to_max = (
|
||||||
|
training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
||||||
|
)
|
||||||
|
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||||
|
template=template,
|
||||||
|
padding="max_length" if pad_to_max else "longest",
|
||||||
|
max_length=data_args.cutoff_len if pad_to_max else None,
|
||||||
|
pad_to_multiple_of=64,
|
||||||
|
label_pad_token_id=IGNORE_INDEX,
|
||||||
|
**tokenizer_module,
|
||||||
|
)
|
||||||
|
data_collator = _data_collator_wrapper(data_collator)
|
||||||
|
|
||||||
|
trainer = McaTrainer(
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
callbacks=callbacks,
|
||||||
|
**dataset_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
|
||||||
|
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
|
||||||
|
|
||||||
|
train_result = trainer.train(training_args.resume_from_checkpoint)
|
||||||
|
trainer.save_model()
|
||||||
|
trainer.log_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
|
keys = ["loss"]
|
||||||
|
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||||
|
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
|
||||||
|
else:
|
||||||
|
keys += ["eval_loss"]
|
||||||
|
plot_loss(training_args.output_dir, keys=keys)
|
||||||
|
|
||||||
|
|
||||||
|
def run_dpo(
|
||||||
|
model_args: ModelArguments,
|
||||||
|
data_args: DataArguments,
|
||||||
|
training_args: McaSeq2SeqTrainingArguments,
|
||||||
|
finetuning_args: FinetuningArguments,
|
||||||
|
callbacks: list[TrainerCallback] | None = None,
|
||||||
|
):
|
||||||
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
|
template = get_template_and_fix_tokenizer(tokenizer, data_args)
|
||||||
|
|
||||||
|
_check_model_support(model_args)
|
||||||
|
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||||
|
|
||||||
|
if finetuning_args.use_ref_model:
|
||||||
|
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
|
||||||
|
ref_model = AutoModel.from_config(ref_config)
|
||||||
|
ref_model.load_state_dict(model.state_dict())
|
||||||
|
else:
|
||||||
|
ref_model = None
|
||||||
|
|
||||||
|
# dataset needs +1 then cut back due to MCA shift logic
|
||||||
|
data_args.cutoff_len += 1
|
||||||
|
dataset_module = get_dataset(template, model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||||
|
data_args.cutoff_len -= 1
|
||||||
|
|
||||||
|
pad_to_max = (
|
||||||
|
training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
||||||
|
)
|
||||||
|
dpo_config = DPOConfig(
|
||||||
|
beta=finetuning_args.pref_beta,
|
||||||
|
pref_loss=finetuning_args.pref_loss,
|
||||||
|
label_smoothing=finetuning_args.dpo_label_smoothing,
|
||||||
|
)
|
||||||
|
data_collator = PairwiseDataCollatorWithPadding(
|
||||||
|
template=template,
|
||||||
|
pad_to_multiple_of=64,
|
||||||
|
padding="max_length" if pad_to_max else "longest",
|
||||||
|
max_length=data_args.cutoff_len if pad_to_max else None,
|
||||||
|
label_pad_token_id=IGNORE_INDEX,
|
||||||
|
**tokenizer_module,
|
||||||
|
)
|
||||||
|
data_collator = _data_collator_wrapper(data_collator)
|
||||||
|
|
||||||
|
trainer = McaDPOTrainer(
|
||||||
|
model=model,
|
||||||
|
ref_model=ref_model,
|
||||||
|
args=training_args,
|
||||||
|
train_config=dpo_config,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
data_collator=data_collator,
|
||||||
|
callbacks=callbacks,
|
||||||
|
**dataset_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "processor" in tokenizer_module and tokenizer_module["processor"] is not None:
|
||||||
|
trainer.add_callback(SaveProcessorCallback(tokenizer_module["processor"]))
|
||||||
|
|
||||||
|
train_result = trainer.train(training_args.resume_from_checkpoint)
|
||||||
|
trainer.save_model()
|
||||||
|
if finetuning_args.include_effective_tokens_per_second:
|
||||||
|
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
|
||||||
|
dataset_module["train_dataset"], train_result.metrics, stage="rm"
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.log_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
|
keys = ["loss", "rewards/accuracies"]
|
||||||
|
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||||
|
keys += [f"eval_{key}_loss" for key in dataset_module["eval_dataset"].keys()]
|
||||||
|
else:
|
||||||
|
keys += ["eval_loss"]
|
||||||
|
|
||||||
|
plot_loss(training_args.output_dir, keys=keys)
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
|
|||||||
from ..extras import logging
|
from ..extras import logging
|
||||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.misc import infer_optim_dtype
|
from ..extras.misc import infer_optim_dtype
|
||||||
from ..extras.packages import is_ray_available
|
from ..extras.packages import is_mcore_adapter_available, is_ray_available
|
||||||
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
|
from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||||
@ -66,7 +66,19 @@ def _training_function(config: dict[str, Any]) -> None:
|
|||||||
|
|
||||||
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args)) # add to last
|
||||||
|
|
||||||
|
if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
|
||||||
|
if not is_mcore_adapter_available():
|
||||||
|
raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
|
||||||
if finetuning_args.stage == "pt":
|
if finetuning_args.stage == "pt":
|
||||||
|
from .mca import run_pt as run_pt_mca
|
||||||
|
run_pt_mca(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
|
elif finetuning_args.stage == "sft":
|
||||||
|
from .mca import run_sft as run_sft_mca
|
||||||
|
run_sft_mca(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
|
else: # dpo
|
||||||
|
from .mca import run_dpo as run_dpo_mca
|
||||||
|
run_dpo_mca(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
|
elif finetuning_args.stage == "pt":
|
||||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||||
elif finetuning_args.stage == "sft":
|
elif finetuning_args.stage == "sft":
|
||||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user