11 Commits

Author SHA1 Message Date
Yaowei Zheng
b5cb7cb0e6 [misc] fix constants (#10232) 2026-03-02 11:10:48 +08:00
Philip Ottesen
0779846513 [infer] support mixed multimodal payloads (#10225)
Signed-off-by: Philip Ottesen <phiott256@gmail.com>
2026-02-28 20:26:53 +08:00
jiaqiw09
45d335c709 [v1] add seed for training and fix gradient checkpointing (#10211) 2026-02-28 18:16:06 +08:00
Kingsley
816480012f [fix] register visual part for Qwen3.5 (#10227) 2026-02-28 16:39:24 +08:00
Mikko Tukiainen
d3bf882e87 [docker] upgrade to ROCm 7.2 base image, drop PyTorch reinstall (#10223)
Co-authored-by: Mikko Tukiainen <mtukiain@chi-mi300x-012.ord.vultr.cpe.ice.amd.com>
2026-02-27 20:16:33 +08:00
娄宗志
589da21d32 [model] support Aeva (#10214) 2026-02-26 23:03:13 +08:00
Yaowei Zheng
122cd46084 [model] update constants (#10220) 2026-02-26 21:13:56 +08:00
浮梦
2b8b871475 [model] Adapt Qwen3.5 (#10213)
Co-authored-by: frozenleaves <frozen@Mac.local>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
2026-02-26 20:45:02 +08:00
Shanay Mehta
aab9b400bb [model] Add DeepSpeed Z3 leaf module for Qwen3-Next (#10194)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 19:54:37 +08:00
P. Clawmogorov
50599c719b [misc] remove safe_serialization arg for transformers v5 compatibility (#10208)
Co-authored-by: P. Clawmogorov <262173731+Alm0stSurely@users.noreply.github.com>
2026-02-24 11:14:19 +08:00
Kingsley
a0f3ad0cee [mca] update supported models (#10196) 2026-02-20 22:02:49 +08:00
30 changed files with 403 additions and 136 deletions

View File

@@ -25,16 +25,16 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
           python-version: '3.10'
       - name: Install dependencies
         run: |
           pip install -r docs/requirements.txt
       - name: Build Sphinx
         run: |
           sphinx-build -b html docs/zh docs/_build/html/zh
@@ -56,10 +56,10 @@ jobs:
           > docs/_build/html/index.html
           touch docs/_build/html/.nojekyll
       - name: Setup Pages
         uses: actions/configure-pages@v5
       - name: Upload artifact
         uses: actions/upload-pages-artifact@v3
         with:

View File

@@ -291,7 +291,7 @@ Read technical notes:
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
+| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -319,6 +319,7 @@ Read technical notes:
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
+| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |

View File

@@ -293,7 +293,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [GPT-2](https://huggingface.co/openai-community) | 0.1B/0.4B/0.8B/1.5B | - |
 | [GPT-OSS](https://huggingface.co/openai) | 20B/120B | gpt_oss |
 | [Granite 3-4](https://huggingface.co/ibm-granite) | 1B/2B/3B/7B/8B | granite3/granite4 |
-| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small |
+| [Hunyuan/Hunyuan1.5 (MT)](https://huggingface.co/tencent/) | 0.5B/1.8B/4B/7B/13B | hunyuan/hunyuan_small|
 | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 |
 | [InternVL 2.5-3.5](https://huggingface.co/OpenGVLab) | 1B/2B/4B/8B/14B/30B/38B/78B/241B | intern_vl |
 | [Intern-S1-mini](https://huggingface.co/internlm/) | 8B | intern_s1 |
@@ -321,6 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
 | [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
 | [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
+| [Qwen3.5](https://huggingface.co/Qwen) | 27B/35B/122B/397B | qwen3_5 |
 | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
 | [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
 | [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |

View File

@@ -1,12 +1,12 @@
 # https://hub.docker.com/r/rocm/pytorch/tags
-ARG BASE_IMAGE=rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
+# ROCm 7.2 + PyTorch 2.7.1 (Python 3.12). Keep base image's PyTorch; do not reinstall.
+ARG BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
 FROM ${BASE_IMAGE}

 # Installation arguments
 ARG PIP_INDEX=https://pypi.org/simple
 ARG INSTALL_FLASHATTN=false
 ARG HTTP_PROXY=""
-ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3

 # Define environments
 ENV MAX_JOBS=16
@@ -32,10 +32,9 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
 # Copy the application into the image
 COPY . /app

-# Reinstall pytorch rocm and install LLaMA Factory
-RUN pip uninstall -y torch torchvision torchaudio && \
-    pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
-    pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
+# Install LLaMA Factory (use base image's PyTorch/ROCm; do not reinstall)
+RUN pip install --no-cache-dir -e . --pre && \
+    pip install --no-cache-dir -r requirements/deepspeed.txt -r requirements/liger-kernel.txt -r requirements/bitsandbytes.txt

 # Rebuild flash attention
 RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \

View File

@@ -47,4 +47,3 @@
   border-color: rgba(255, 255, 255, 0.45);
   box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
 }

View File

@@ -1,33 +1,31 @@
 # Configuration file for the Sphinx documentation builder.
-import os
-import sys

 # Define common settings here
-project = 'LlamaFactory'
-copyright = '2024, LlamaFactory Team'
-author = 'LlamaFactory Team'
+project = "LlamaFactory"
+copyright = "2024, LlamaFactory Team"
+author = "LlamaFactory Team"

 extensions = [
-    'sphinx.ext.autodoc',
-    'sphinx.ext.viewcode',
-    'sphinx.ext.napoleon',
-    'myst_parser',
+    "sphinx.ext.autodoc",
+    "sphinx.ext.viewcode",
+    "sphinx.ext.napoleon",
+    "myst_parser",
 ]

-templates_path = ['_templates']
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
-html_theme = 'sphinx_rtd_theme'
-html_static_path = ['_static']
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
+html_theme = "sphinx_rtd_theme"
+html_static_path = ["_static"]

 html_js_files = [
-    'js/switcher.js',
+    "js/switcher.js",
 ]
 html_css_files = [
-    'css/lang-switcher.css',
+    "css/lang-switcher.css",
 ]

 myst_enable_extensions = [

View File

@@ -1,20 +1,22 @@
 import os
 import sys

-# Add parent dir to path to allow importing conf.py
-sys.path.insert(0, os.path.abspath('..'))
-from conf import *
+# Add parent dir to path to allow importing conf.py
+sys.path.insert(0, os.path.abspath(".."))
+from conf import *  # noqa: F403

 # Language settings
-language = 'en'
-html_search_language = 'en'
+language = "en"
+html_search_language = "en"

 # Static files
 # Point to the root _static directory
-html_static_path = ['../_static']
+html_static_path = ["../_static"]

 # Add custom JS for language switcher
 html_js_files = [
-    'js/switcher.js',
+    "js/switcher.js",
 ]

View File

@@ -1,20 +1,22 @@
 import os
 import sys

-# Add parent dir to path to allow importing conf.py
-sys.path.insert(0, os.path.abspath('..'))
-from conf import *
+# Add parent dir to path to allow importing conf.py
+sys.path.insert(0, os.path.abspath(".."))
+from conf import *  # noqa: F403

 # Language settings
-language = 'zh_CN'
-html_search_language = 'zh'
+language = "zh_CN"
+html_search_language = "zh"

 # Static files
 # Point to the root _static directory
-html_static_path = ['../_static']
+html_static_path = ["../_static"]

 # Add custom JS for language switcher
 html_js_files = [
-    'js/switcher.js',
+    "js/switcher.js",
 ]
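The per-language conf.py files above share one base configuration by pushing the parent docs directory onto sys.path and star-importing the root conf. A self-contained sketch of that mechanism (the temporary directory layout and file contents below are illustrative, mirroring only the import trick from the diff; Sphinx itself executes conf.py as a plain script, which runpy reproduces here):

```python
import os
import runpy
import tempfile

# Throwaway docs tree: a shared conf.py at the root, plus an en/conf.py
# that star-imports it and overrides the language.
root = tempfile.mkdtemp()
with open(os.path.join(root, "conf.py"), "w") as f:
    f.write('project = "LlamaFactory"\n')

en_dir = os.path.join(root, "en")
os.makedirs(en_dir)
with open(os.path.join(en_dir, "conf.py"), "w") as f:
    f.write(
        "import os\n"
        "import sys\n"
        "# Add parent dir to path to allow importing conf.py\n"
        'sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))\n'
        "from conf import *  # noqa: F403\n"
        'language = "en"\n'
    )

# Execute en/conf.py as a script and inspect the resulting globals.
ns = runpy.run_path(os.path.join(en_dir, "conf.py"))
print(ns["project"], ns["language"])
```

The star import copies every shared setting into the per-language namespace, so each language build only needs to state what differs.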

View File

@@ -6,14 +6,14 @@ template: qwen3_nothink
 kernel_config:
   name: auto
   include_kernels: auto
 dist_config:
   name: deepspeed
   config_file: examples/deepspeed/ds_z3_config.json

 ### data
 train_dataset: data/v1_sft_demo.yaml

 ### training
 output_dir: outputs/Qwen3-0.6B-deepspeed
@@ -22,4 +22,3 @@ cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: true
 max_steps: 10

View File

@@ -14,16 +14,12 @@ dist_config:
   name: fsdp2
   dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
-init_config:
-  name: init_on_meta

 ### data
 train_dataset: data/v1_sft_demo.yaml

 ### training
 output_dir: outputs/test_fsdp2
 micro_batch_size: 1
-global_batch_size: 1
 cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: false

View File

@@ -40,7 +40,7 @@ dependencies = [
     "torch>=2.4.0",
     "torchvision>=0.19.0",
     "torchaudio>=2.4.0",
-    "transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
+    "transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0",
     "datasets>=2.16.0,<=4.0.0",
     "accelerate>=1.3.0,<=1.11.0",
     "peft>=0.18.0,<=0.18.1",

View File

@@ -154,25 +154,24 @@ def vllm_infer(
         batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
         for j in range(len(batch["input_ids"])):
+            multi_modal_data = {}
+            video_metadata_kwargs = None
             if batch["images"][j] is not None:
                 image = batch["images"][j]
-                multi_modal_data = {
-                    "image": template_obj.mm_plugin._regularize_images(
-                        image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                    )["images"]
-                }
-            elif batch["videos"][j] is not None:
-                video_metadata, video_metadata_kwargs = None, None
+                multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
+                    image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )["images"]

+            if batch["videos"][j] is not None:
                 video = batch["videos"][j]
-                multi_modal_data = {
-                    "video": template_obj.mm_plugin._regularize_videos(
-                        video,
-                        image_max_pixels=image_max_pixels,
-                        image_min_pixels=image_min_pixels,
-                        video_fps=video_fps,
-                        video_maxlen=video_maxlen,
-                    )["videos"]
-                }
+                multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
+                    video,
+                    image_max_pixels=image_max_pixels,
+                    image_min_pixels=image_min_pixels,
+                    video_fps=video_fps,
+                    video_maxlen=video_maxlen,
+                )["videos"]
                 if need_video_kwargs:
                     container = av.open(video[0], "r")
                     video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -192,18 +191,17 @@ def vllm_infer(
                         video_backend="opencv",
                     )
                     multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
-            elif batch["audios"][j] is not None:
+
+            if batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
                     audio,
                     sampling_rate=16000,
                 )
-                multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-            else:
-                multi_modal_data = None
+                multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])

-            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
-            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
+            if video_metadata_kwargs is not None:
                 vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
             vllm_inputs.append(vllm_input_data)
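The refactor above replaces mutually exclusive `elif` branches with one dict that accumulates every modality present, which is what allows a single sample to carry image, video, and audio together; `or None` preserves the old behavior for text-only samples. A minimal vLLM-free sketch of that pattern (the function name and its arguments are illustrative stand-ins, not the real mm_plugin helpers):

```python
def build_multi_modal_data(images=None, videos=None, audios=None):
    """Collect whichever modalities are present into one payload dict."""
    multi_modal_data = {}
    if images is not None:  # independent ifs, not elif: modalities can mix
        multi_modal_data["image"] = images
    if videos is not None:
        multi_modal_data["video"] = videos
    if audios is not None:
        multi_modal_data["audio"] = audios
    return multi_modal_data or None  # None signals a text-only sample


print(build_multi_modal_data(images=["img0"], audios=["wav0"]))
print(build_multi_modal_data())
```

With the previous `elif` chain, a sample with both an image and an audio clip would silently drop the audio; the accumulating dict keeps both keys.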

View File

@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
             else self.generating_args["skip_special_tokens"],
         )

+        multi_modal_data = {}
         if images is not None:  # add image features
-            multi_modal_data = {
-                "image": self.template.mm_plugin._regularize_images(
-                    images,
-                    image_max_pixels=self.model_args.image_max_pixels,
-                    image_min_pixels=self.model_args.image_min_pixels,
-                )["images"]
-            }
-        elif videos is not None:
-            multi_modal_data = {
-                "video": self.template.mm_plugin._regularize_videos(
-                    videos,
-                    image_max_pixels=self.model_args.video_max_pixels,
-                    image_min_pixels=self.model_args.video_min_pixels,
-                    video_fps=self.model_args.video_fps,
-                    video_maxlen=self.model_args.video_maxlen,
-                )["videos"]
-            }
-        elif audios is not None:
+            multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
+                images,
+                image_max_pixels=self.model_args.image_max_pixels,
+                image_min_pixels=self.model_args.image_min_pixels,
+            )["images"]

+        if videos is not None:
+            multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
+                videos,
+                image_max_pixels=self.model_args.video_max_pixels,
+                image_min_pixels=self.model_args.video_min_pixels,
+                video_fps=self.model_args.video_fps,
+                video_maxlen=self.model_args.video_maxlen,
+            )["videos"]

+        if audios is not None:
             audio_data = self.template.mm_plugin._regularize_audios(
                 audios,
                 sampling_rate=self.model_args.audio_sampling_rate,
             )
-            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
-        else:
-            multi_modal_data = None
+            multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])

         result_generator = self.model.generate(
-            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,

View File

@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal, Optional
@@ -189,6 +190,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 "video_grid_thw": mm_inputs.get("video_grid_thw"),
                 "attention_mask": (features["attention_mask"] >= 1).float(),
             }
+            if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
+                image_token_id = getattr(self.model.config, "image_token_id", None)
+                video_token_id = getattr(self.model.config, "video_token_id", None)
+                if image_token_id is not None or video_token_id is not None:
+                    mm_token_type_ids = torch.zeros_like(features["input_ids"])
+                    if image_token_id is not None:
+                        mm_token_type_ids[features["input_ids"] == image_token_id] = 1
+                    if video_token_id is not None:
+                        mm_token_type_ids[features["input_ids"] == video_token_id] = 2
+                    rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids

             if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
             elif "video_second_per_grid" in mm_inputs:  # for qwen2.5 omni
@@ -219,6 +230,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             "qwen2_5_vl",
             "qwen2_5_omni_thinker",
             "qwen3_omni_moe_thinker",
+            "qwen3_5",
             "qwen3_vl",
             "qwen3_vl_moe",
         ]
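The collator change above labels each input position by modality (0 = text, 1 = image placeholder, 2 = video placeholder) before handing the batch to the rope-index function. A torch-free sketch of the same labeling over plain lists (the token ids below are illustrative stand-ins for the model's `<|image_pad|>`/`<|video_pad|>` ids):

```python
def build_mm_token_type_ids(input_ids, image_token_id=None, video_token_id=None):
    """Mark each position: 0 = text, 1 = image placeholder, 2 = video placeholder."""
    type_ids = [0] * len(input_ids)
    for i, tok in enumerate(input_ids):
        if image_token_id is not None and tok == image_token_id:
            type_ids[i] = 1
        elif video_token_id is not None and tok == video_token_id:
            type_ids[i] = 2
    return type_ids


# 151655 stands in for the image pad token id, 151656 for the video pad token id.
ids = [1, 151655, 151655, 2, 151656, 3]
print(build_mm_token_type_ids(ids, image_token_id=151655, video_token_id=151656))
```

The real collator does the same thing vectorized with boolean masks on a `torch.zeros_like` tensor, and only when the model's rope function actually accepts `mm_token_type_ids`.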

View File

@@ -2029,6 +2029,39 @@ register_template(
 )

+register_template(
+    name="qwen3_5",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen3_5"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+    template_class=ReasoningTemplate,
+)
+
+register_template(
+    name="qwen3_5_nothink",
+    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
+    format_observation=StringFormatter(
+        slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="qwen3_5"),
+    stop_words=["<|im_end|>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
+)
+
 register_template(
     name="sailor",
     format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
@@ -2218,3 +2251,24 @@ register_template(
     format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
     default_system="You are Zephyr, a helpful assistant.",
 )
+
+# copied from glm4_7 template
+register_template(
+    name="aeva",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4_moe"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    default_system=(
+        "You are an AI assistant named Aeva created by Zongzhi Lou. "
+        "Your answer should be friendly, unbiased, faithful, informative and detailed."
+    ),
+    stop_words=["<|user|>", "<|observation|>"],
+    thought_words=("<think>", "</think>"),
+    efficient_eos=True,
+    template_class=Glm47ReasoningTemplate,
+)

View File

@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
     """"arguments": <args-json-object>}}\n</tool_call>"""
 )

+QWEN35_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
+    "\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
+    "<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
+    "<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
+    "</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
+    "- Function calls MUST follow the specified format: "
+    "an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
+    "- Required parameters MUST be specified\n"
+    "- You may provide optional reasoning for your function call in natural language "
+    "BEFORE the function call, but NOT after\n"
+    "- If there is no function call available, answer the question like normal with your current knowledge "
+    "and do not tell the user about function calls\n</IMPORTANT>"
+)
+
 SEED_TOOL_PROMPT = (
     "system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
     "Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
         return results

+
+class Qwen35ToolUtils(ToolUtils):
+    r"""Qwen 3.5 tool using template."""
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            tool = tool.get("function", tool) if tool.get("type") == "function" else tool
+            tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
+
+        return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @override
+    @staticmethod
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        function_texts = []
+        for func in functions:
+            name, arguments = func.name, json.loads(func.arguments)
+            prompt = f"<tool_call>\n<function={name}>"
+            for key, value in arguments.items():
+                prompt += f"\n<parameter={key}>"
+                if not isinstance(value, str):
+                    value = json.dumps(value, ensure_ascii=False)
+                prompt += f"\n{value}\n</parameter>"
+            prompt += "\n</function>\n</tool_call>"
+            function_texts.append(prompt)
+
+        return "\n".join(function_texts)
+
+    @override
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+        results = []
+        regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
+        for func_name, params_block in re.findall(regex, content):
+            args_dict = {}
+            param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
+            for key, raw_value in re.findall(param_pattern, params_block.strip()):
+                value = raw_value.strip()
+                try:
+                    parsed_value = json.loads(value)
+                except json.JSONDecodeError:
+                    parsed_value = raw_value.strip()
+
+                args_dict[key] = parsed_value
+
+            results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
+
+        return results if results else content
+
+
 class GLM4MOEToolUtils(QwenToolUtils):
     r"""GLM-4-MOE tool using template."""
@@ -662,6 +728,7 @@ TOOLS = {
     "minimax2": MiniMaxM2ToolUtils(),
     "mistral": MistralToolUtils(),
     "qwen": QwenToolUtils(),
+    "qwen3_5": Qwen35ToolUtils(),
     "glm4_moe": GLM4MOEToolUtils(),
     "seed_oss": SeedToolUtils(),
     "ling": LingToolUtils(),
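The two regexes in `Qwen35ToolUtils.tool_extractor` can be exercised standalone. The sketch below reuses them verbatim on a sample completion (the function name `extract_tool_calls` and the sample text are illustrative, and it returns plain tuples instead of `FunctionCall` objects):

```python
import json
import re

# Same patterns as Qwen35ToolUtils.tool_extractor in the diff above.
TOOL_CALL_RE = re.compile(
    r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL
)
PARAM_RE = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)


def extract_tool_calls(content):
    results = []
    for func_name, params_block in TOOL_CALL_RE.findall(content):
        args = {}
        for key, raw_value in PARAM_RE.findall(params_block.strip()):
            value = raw_value.strip()
            try:
                args[key] = json.loads(value)  # numbers/objects parse as JSON
            except json.JSONDecodeError:
                args[key] = value  # plain text stays a string
        results.append((func_name.strip(), args))
    return results


sample = (
    "<tool_call>\n<function=get_weather>\n"
    "<parameter=city>\nBeijing\n</parameter>\n"
    "<parameter=days>\n3\n</parameter>\n"
    "</function>\n</tool_call>"
)
print(extract_tool_calls(sample))
```

Note the outer pattern has no literal `>` after the function name, so the closing `>` of `<function=name>` is swallowed by the non-greedy params group; the inner `<parameter=...>` pattern then picks the arguments out of that block, and JSON-looking values (like `3`) come back typed.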

View File

@@ -65,6 +65,7 @@ MCA_SUPPORTED_MODELS = {
     "qwen2_vl",
     "qwen2_5_vl",
     "qwen3_vl",
+    "qwen3_vl_moe",
     "qwen3",
     "qwen3_moe",
     "qwen3_next",
@@ -2809,6 +2810,34 @@ register_model_group(
 )

+register_model_group(
+    models={
+        "Qwen3.5-35B-A3B-Base": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
+        },
+        "Qwen3.5-27B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
+        },
+        "Qwen3.5-35B-A3B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
+        },
+        "Qwen3.5-122B-A10B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
+        },
+        "Qwen3.5-397B-A17B-Thinking": {
+            DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
+        },
+    },
+    template="qwen3_5",
+    multimodal=True,
+)
+
 register_model_group(
     models={
         "Qwen2-Audio-7B": {
@@ -3450,3 +3479,35 @@ register_model_group(
     },
     template="zephyr",
 )
+
+register_model_group(
+    models={
+        "Aeva-Flash-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Flash",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Flash",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Flash",
+        },
+        "Aeva-Air-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Air",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Air",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Air",
+        },
+        "Aeva-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva",
+        },
+        "Aeva-Pro-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Pro",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Pro",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Pro",
+        },
+        "Aeva-Max-Chat": {
+            DownloadSource.DEFAULT: "louzongzhi/Aeva-Max",
+            DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Max",
+            DownloadSource.OPENMIND: "louzongzhi/Aeva-Max",
+        },
+    },
+    template="aeva",
+)

View File

@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.51.0,<=5.0.0")
+    check_version("transformers>=4.51.0,<=5.2.0")
     check_version("datasets>=2.16.0,<=4.0.0")
     check_version("accelerate>=1.3.0,<=1.11.0")
     check_version("peft>=0.18.0,<=0.18.1")

View File

@@ -142,6 +142,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])

+    if model_type == "qwen3_next":
+        from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
+

 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.moe_aux_loss_coef:

View File

@@ -395,6 +395,24 @@ _register_composite_model(
 )

+_register_composite_model(
+    model_type="qwen3_5",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
+_register_composite_model(
+    model_type="qwen3_5_moe",
+    projector_key="visual.merger",
+    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
+    language_model_keys=["language_model", "lm_head"],
+    lora_conflict_keys=["patch_embed"],
+)
+
 _register_composite_model(
     model_type="video_llava",
 )

View File

@@ -82,7 +82,33 @@ def _check_model_support(model_args: "ModelArguments"):
         model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
     )
     if config.model_type not in MCA_SUPPORTED_MODELS:
-        raise ValueError(f"Model {config.model_type} is not supported by MCA.")
+        raise ValueError(
+            f"Model {config.model_type} is not supported by mcore_adapter. "
+            "You can try to upgrade mcore_adapter to the latest version for more supported models."
+        )
+
+
+def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
+    """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
+    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe"]:
+        return
+
+    params_to_freeze = []
+    if finetuning_args.freeze_vision_tower:
+        params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
+        if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
+            params_to_freeze.extend(["vision_model.pos_embed"])
+
+    if finetuning_args.freeze_multi_modal_projector:
+        params_to_freeze.extend(["multi_modal_projector"])
+
+    if finetuning_args.freeze_language_model:
+        params_to_freeze.extend(["embedding", "decoder", "output_layer"])
+
+    if params_to_freeze:
+        for name, p in model.named_parameters():
+            if any(name.startswith(k) for k in params_to_freeze):
+                p.requires_grad_(False)
+

 def run_pt(
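`_freeze_model_parameters` freezes by matching parameter-name prefixes. The core idea can be sketched without torch; the `Param` stub below stands in for `nn.Parameter`, and the name/param mapping for `model.named_parameters()` (both are illustrative stand-ins, not the project's types):

```python
class Param:
    """Stand-in for nn.Parameter: only tracks the requires_grad flag."""

    def __init__(self) -> None:
        self.requires_grad = True

    def requires_grad_(self, flag: bool) -> None:
        self.requires_grad = flag


def freeze_by_prefix(named_parameters: dict[str, Param], prefixes: list[str]) -> int:
    """Disable gradients for every parameter whose name starts with one of the prefixes."""
    frozen = 0
    for name, p in named_parameters.items():
        if any(name.startswith(k) for k in prefixes):
            p.requires_grad_(False)
            frozen += 1
    return frozen
```

With real models the loop runs over `model.named_parameters()` exactly as in the hunk above.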
@@ -161,22 +187,8 @@ def run_sft(
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)

-    # optional freezing for qwen2_vl, qwen2_5_vl
-    if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
-        params_to_freeze = []
-        if finetuning_args.freeze_vision_tower:
-            params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
-        if finetuning_args.freeze_multi_modal_projector:
-            params_to_freeze.extend(["multi_modal_projector"])
-        if finetuning_args.freeze_language_model:
-            params_to_freeze.extend(["embedding", "decoder", "output_layer"])
-        if params_to_freeze:
-            for name, p in model.named_parameters():
-                if any(name.startswith(k) for k in params_to_freeze):
-                    p.requires_grad_(False)
+    # optional freezing for qwen_vl series
+    _freeze_model_parameters(model, finetuning_args)

     pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
     data_collator = SFTDataCollatorWith4DAttentionMask(
@@ -229,6 +241,8 @@ def run_dpo(
     _check_model_support(model_args)
     model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
+    _freeze_model_parameters(model, finetuning_args)
+
     if finetuning_args.use_ref_model:
         ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
         ref_model = AutoModel.from_config(ref_config)

View File

@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available
+from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
 from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
         model = model.to(output_dtype)
         logger.info_rank0(f"Convert model dtype to: {output_dtype}.")

-    model.save_pretrained(
-        save_directory=model_args.export_dir,
-        max_shard_size=f"{model_args.export_size}GB",
-        safe_serialization=(not model_args.export_legacy_format),
-    )
+    # Prepare save arguments (safe_serialization removed in transformers v5.0.0)
+    save_kwargs = {
+        "save_directory": model_args.export_dir,
+        "max_shard_size": f"{model_args.export_size}GB",
+    }
+    if not is_transformers_version_greater_than("5.0.0"):
+        save_kwargs["safe_serialization"] = not model_args.export_legacy_format
+
+    model.save_pretrained(**save_kwargs)
+
     if model_args.export_hub_model_id is not None:
+        # Prepare push arguments (safe_serialization removed in transformers v5.0.0)
+        push_kwargs = {
+            "max_shard_size": f"{model_args.export_size}GB",
+        }
+        if not is_transformers_version_greater_than("5.0.0"):
+            push_kwargs["safe_serialization"] = not model_args.export_legacy_format
+
         model.push_to_hub(
             model_args.export_hub_model_id,
             token=model_args.hf_hub_token,
-            max_shard_size=f"{model_args.export_size}GB",
-            safe_serialization=(not model_args.export_legacy_format),
+            **push_kwargs,
         )

     if finetuning_args.stage == "rm":
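The version gate above reduces to a small pure function: only pass `safe_serialization` when the installed transformers predates v5.0.0, where the argument was removed. The sketch below replaces `is_transformers_version_greater_than` with a naive numeric comparison (an assumption for self-containment; the real helper uses proper version parsing):

```python
def _ver(v: str) -> tuple[int, ...]:
    # naive: assumes dotted numeric versions like "4.57.1" or "5.0.0"
    return tuple(int(x) for x in v.split("."))


def build_save_kwargs(export_dir: str, export_size: int, legacy_format: bool,
                      transformers_version: str) -> dict:
    """Build save_pretrained kwargs, gating safe_serialization on the version."""
    kwargs = {"save_directory": export_dir, "max_shard_size": f"{export_size}GB"}
    if _ver(transformers_version) < _ver("5.0.0"):
        kwargs["safe_serialization"] = not legacy_format
    return kwargs
```

The same gating is repeated for `push_to_hub`, which dropped the argument at the same release.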

View File

@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
 from transformers import HfArgumentParser

 from ..utils.env import is_env_enabled
+from ..utils.helper import set_seed
 from .data_args import DataArguments
 from .model_args import ModelArguments
 from .sample_args import SampleArguments
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
         print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
         raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")

+    # Seed as early as possible after argument parsing so that all downstream
+    # components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
+    for arg in parsed_args:
+        seed = getattr(arg, "seed", None)
+        if seed is not None:
+            set_seed(seed)
+            break
+
     return tuple(parsed_args)
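The early-seeding loop above takes the first parsed argument object that exposes a non-None `seed` attribute and seeds from it. A minimal reproduction of that lookup, with `random.seed` standing in for the project's `set_seed` helper:

```python
import random


def seed_from_args(parsed_args, set_seed=random.seed):
    """Seed from the first argument object that carries a non-None `seed`."""
    for arg in parsed_args:
        seed = getattr(arg, "seed", None)
        if seed is not None:
            set_seed(seed)
            return seed
    return None
```

Because it stops at the first match, only one argument group (here, the training arguments) needs to define `seed`.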

View File

@@ -66,7 +66,7 @@ class TrainingArguments:
         metadata={"help": "Number of workers for batching."},
     )
     enable_activation_checkpointing: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Enable activation checkpointing for training."},
     )
     dist_config: PluginConfig | None = field(
@@ -81,6 +81,10 @@ class TrainingArguments:
         default=None,
         metadata={"help": "Learning rate scheduler configuration for training."},
     )
+    seed: int = field(
+        default=42,
+        metadata={"help": "Random seed that will be set at the beginning of training."},
+    )

     def __post_init__(self) -> None:
         self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -76,7 +76,7 @@ class BaseTrainer:
         if self.args.enable_activation_checkpointing:
             self.model.gradient_checkpointing_enable({"use_reentrant": False})

-        self._accelerate_engine = None
+        self._deepspeed_engine = None
         dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
         if dist_name == "deepspeed":
@@ -108,6 +108,7 @@ class BaseTrainer:
             cutoff_len=self.args.cutoff_len,
             batching_workers=self.args.batching_workers,
             batching_strategy=self.args.batching_strategy,
+            seed=self.args.seed,
         )

     def _shard_model(self) -> None:
def _shard_model(self) -> None: def _shard_model(self) -> None:

View File

@@ -26,6 +26,7 @@
 from collections.abc import Iterator
 from typing import Any

+import torch
 from torch.utils.data import default_collate
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
         batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
         pin_memory: bool = True,
         drop_last: bool = True,
+        seed: int = 42,
     ) -> None:
         self.dataset = dataset
         self.renderer = renderer
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
         self.batching_strategy = batching_strategy
         self.pin_memory = pin_memory
         self.drop_last = drop_last
+        self.seed = seed

         # TODO: support length and infinity
         dp_size = DistributedInterface().get_world_size(Dim.DP)
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
                 num_replicas=DistributedInterface().get_world_size(Dim.DP),
                 rank=DistributedInterface().get_rank(Dim.DP),
                 shuffle=True,
-                seed=0,
+                seed=self.seed,
                 drop_last=self.drop_last,
             )
         else:
             raise NotImplementedError("Iterable dataset is not supported yet.")

+        generato_seed = torch.Generator()
+        generato_seed.manual_seed(self.seed)
+
         self._data_provider = StatefulDataLoader(
             self.dataset,
             batch_size=self.micro_batch_size * self.num_micro_batch,
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
             pin_memory=self.pin_memory,
             pin_memory_device=DistributedInterface().current_device.type,
             drop_last=self.drop_last,
+            generator=generato_seed,
         )

         if self.batching_strategy == BatchingStrategy.NORMAL:
             self._length = len(self._data_provider)
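Passing a seeded `torch.Generator` to the dataloader makes shuffle order reproducible across runs with the same seed. The same property, demonstrated with Python's stdlib RNG in place of `torch.Generator` (an analogy for illustration, not the real code path):

```python
import random


def shuffled_indices(n: int, seed: int) -> list[int]:
    """Deterministic shuffle: same seed, same order, every run."""
    # analogous to: g = torch.Generator(); g.manual_seed(seed)
    rng = random.Random(seed)
    idx = list(range(n))
    rng.shuffle(idx)
    return idx
```

The sampler seed covers the per-rank partitioning, while the dataloader generator covers any in-loader randomness; both now derive from the one `seed` argument.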

View File

@@ -166,12 +166,11 @@ class FSDP2Engine:
             offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
         )

-        use_gradient_checkpointing = True  # Could be configurable
-        if use_gradient_checkpointing:
+        # BaseTrainer is the single source of truth for gradient checkpointing.
+        # FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
+        if getattr(model, "is_gradient_checkpointing", False):
             if self.rank == 0:
-                logger.info("Enabling gradient checkpointing (transformers native)...")
-            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+                logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
             if hasattr(model, "enable_input_require_grads"):
                 model.enable_input_require_grads()
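The FSDP2 change replaces unconditional enabling with a probe of state the trainer has already set, avoiding a double-enable. A small sketch of that guard (the `Stub` class in the test is hypothetical; real models expose these attributes via transformers):

```python
def prepare_input_grads(model) -> bool:
    """Apply the input-grad hook only if checkpointing is already enabled elsewhere."""
    if not getattr(model, "is_gradient_checkpointing", False):
        return False
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    return True
```

`getattr(..., False)` keeps the guard safe for models that never define the attribute.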

View File

@@ -15,12 +15,22 @@
 import torch
 from transformers import PreTrainedTokenizer
+from transformers import set_seed as hf_set_seed

 from ..accelerator.interface import DistributedInterface
 from .constants import IGNORE_INDEX
 from .types import BatchInput, ModelInput, Processor, Tensor


+def set_seed(seed: int) -> None:
+    """Set seed for reproducibility.
+
+    Args:
+        seed: Random seed.
+    """
+    hf_set_seed(seed)
+
+
 def is_tokenizer(processor: Processor) -> bool:
     """Check if processor is tokenizer.
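The wrapper above delegates to `transformers.set_seed`, which seeds Python, NumPy, and torch in one call. A stdlib-only sketch of the same idea (deliberately seeding fewer RNGs than the real helper; the `PYTHONHASHSEED` line only affects subprocesses):

```python
import os
import random


def set_seed(seed: int) -> None:
    """Minimal reproducibility sketch: seeds Python's RNG and the hash seed only."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
```

Re-seeding with the same value replays the same random stream, which is what makes the trainer's runs repeatable.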

View File

@@ -22,6 +22,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText
 from llamafactory.data import get_template_and_fix_tokenizer
 from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
 from llamafactory.extras.constants import IGNORE_INDEX
+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer
@@ -116,14 +117,16 @@ def test_multimodal_collator():
         "labels": [
             [0, 1, 2, 3, q, q, q, q, q, q, q, q],
         ],
-        "position_ids": [
-            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
-            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
-            [[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
-        ],
-        "rope_deltas": [[-8]],
+        "position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3,
+        "rope_deltas": [[0]],
         **tokenizer_module["processor"].image_processor(fake_image),
     }
+    if not is_transformers_version_greater_than("5.0.0"):
+        # adapt position_ids and rope_deltas for transformers < 5.0.0
+        # https://github.com/huggingface/transformers/pull/43972
+        expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
+        expected_input["rope_deltas"] = [[-8]]
+
     assert batch_input.keys() == expected_input.keys()
     for k in batch_input.keys():
         assert batch_input[k].eq(torch.tensor(expected_input[k])).all()

View File

@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.5.106
+0.9.5.107