mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-13 07:26:00 +08:00
Compare commits
26 Commits
675ce8cc7f
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
246192abd2 | ||
|
|
0258dc14d0 | ||
|
|
3045adf0ba | ||
|
|
a3d44e3152 | ||
|
|
edeb953bc7 | ||
|
|
d045794387 | ||
|
|
9501c3308a | ||
|
|
0ee1c42c2b | ||
|
|
3061f48d55 | ||
|
|
2d9bd2aa14 | ||
|
|
c0245c43fc | ||
|
|
eb976d75a2 | ||
|
|
b5cb7cb0e6 | ||
|
|
0779846513 | ||
|
|
45d335c709 | ||
|
|
816480012f | ||
|
|
d3bf882e87 | ||
|
|
589da21d32 | ||
|
|
122cd46084 | ||
|
|
2b8b871475 | ||
|
|
aab9b400bb | ||
|
|
50599c719b | ||
|
|
a0f3ad0cee | ||
|
|
f80e15dbb4 | ||
|
|
991267fd3b | ||
|
|
5c52afa30d |
1
.github/workflows/tests_cuda.yml
vendored
1
.github/workflows/tests_cuda.yml
vendored
@@ -61,6 +61,7 @@ jobs:
|
||||
uv venv
|
||||
uv pip install -e .
|
||||
uv pip install -r requirements/dev.txt
|
||||
uv pip install -r requirements/bitsandbytes.txt
|
||||
|
||||
- name: Check quality
|
||||
run: |
|
||||
|
||||
@@ -319,6 +319,7 @@ Read technical notes:
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -472,7 +473,7 @@ huggingface-cli login
|
||||
|
||||
| Mandatory | Minimum | Recommend |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
|
||||
@@ -321,6 +321,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
|
||||
| [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral |
|
||||
| [Qwen2 (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen |
|
||||
| [Qwen3 (MoE/Instruct/Thinking/Next)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/80B/235B | qwen3/qwen3_nothink |
|
||||
| [Qwen3.5](https://huggingface.co/Qwen) | 0.8B/2B/4B/9B/27B/35B/122B/397B | qwen3_5 |
|
||||
| [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio |
|
||||
| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni |
|
||||
| [Qwen3-Omni](https://huggingface.co/Qwen) | 30B | qwen3_omni |
|
||||
@@ -474,7 +475,7 @@ huggingface-cli login
|
||||
|
||||
| 必需项 | 至少 | 推荐 |
|
||||
| ------------ | ------- | --------- |
|
||||
| python | 3.9 | 3.10 |
|
||||
| python | 3.11 | >=3.11 |
|
||||
| torch | 2.0.0 | 2.6.0 |
|
||||
| torchvision | 0.15.0 | 0.21.0 |
|
||||
| transformers | 4.49.0 | 4.50.0 |
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# https://hub.docker.com/r/ascendai/cann/tags
|
||||
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.3.rc2-910b-ubuntu22.04-py3.11
|
||||
ARG BASE_IMAGE=quay.io/ascend/cann:8.5.1-910b-ubuntu22.04-py3.11
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
@@ -33,9 +33,11 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
COPY . /app
|
||||
|
||||
# Install torch-npu
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir "torch==2.7.1" "torch-npu==2.7.1" "torchvision==0.22.1" "torchaudio==2.7.1" --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
RUN source /usr/local/Ascend/ascend-toolkit/set_env.sh
|
||||
RUN pip uninstall -y torch torchvision torchaudio
|
||||
RUN pip install --no-cache-dir -r requirements/npu.txt --index-url "${PYTORCH_INDEX}"
|
||||
RUN pip install --no-cache-dir -r requirements/deepspeed.txt
|
||||
RUN pip install --no-cache-dir -e . --no-build-isolation && \
|
||||
pip install --no-cache-dir -r requirements/metrics.txt --no-build-isolation
|
||||
|
||||
# Set up volumes
|
||||
|
||||
@@ -33,7 +33,7 @@ services:
|
||||
dockerfile: ./docker/docker-npu/Dockerfile
|
||||
context: ../..
|
||||
args:
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.3.rc2-a3-ubuntu22.04-py3.11
|
||||
BASE_IMAGE: quay.io/ascend/cann:8.5.1-a3-ubuntu22.04-py3.11
|
||||
PIP_INDEX: https://pypi.org/simple
|
||||
container_name: llamafactory-a3
|
||||
image: llamafactory:npu-a3
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# https://hub.docker.com/r/rocm/pytorch/tags
|
||||
ARG BASE_IMAGE=rocm/pytorch:rocm6.4.1_ubuntu22.04_py3.10_pytorch_release_2.6.0
|
||||
# ROCm 7.2 + PyTorch 2.7.1 (Python 3.12). Keep base image's PyTorch; do not reinstall.
|
||||
ARG BASE_IMAGE=rocm/pytorch:rocm7.2_ubuntu24.04_py3.12_pytorch_release_2.7.1
|
||||
FROM ${BASE_IMAGE}
|
||||
|
||||
# Installation arguments
|
||||
ARG PIP_INDEX=https://pypi.org/simple
|
||||
ARG INSTALL_FLASHATTN=false
|
||||
ARG HTTP_PROXY=""
|
||||
ARG PYTORCH_INDEX=https://download.pytorch.org/whl/rocm6.3
|
||||
|
||||
# Define environments
|
||||
ENV MAX_JOBS=16
|
||||
@@ -32,10 +32,9 @@ RUN pip config set global.index-url "${PIP_INDEX}" && \
|
||||
# Copy the application into the image
|
||||
COPY . /app
|
||||
|
||||
# Reinstall pytorch rocm and install LLaMA Factory
|
||||
RUN pip uninstall -y torch torchvision torchaudio && \
|
||||
pip install --no-cache-dir --no-build-isolation -e --pre . --index-url "${PYTORCH_INDEX}" && \
|
||||
pip install --no-cache-dir --no-build-isolation -r requirements/metrics.txt -r requirements/deepspeed.txt --index-url "${PYTORCH_INDEX}"
|
||||
# Install LLaMA Factory (use base image's PyTorch/ROCm; do not reinstall)
|
||||
RUN pip install --no-cache-dir -e . --pre && \
|
||||
pip install --no-cache-dir -r requirements/deepspeed.txt -r requirements/liger-kernel.txt -r requirements/bitsandbytes.txt
|
||||
|
||||
# Rebuild flash attention
|
||||
RUN if [ "${INSTALL_FLASHATTN}" == "true" ]; then \
|
||||
|
||||
1
docs/_static/css/lang-switcher.css
vendored
1
docs/_static/css/lang-switcher.css
vendored
@@ -47,4 +47,3 @@
|
||||
border-color: rgba(255, 255, 255, 0.45);
|
||||
box-shadow: 0 0 0 3px rgba(255, 255, 255, 0.12);
|
||||
}
|
||||
|
||||
|
||||
28
docs/conf.py
28
docs/conf.py
@@ -1,33 +1,31 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Define common settings here
|
||||
project = 'LlamaFactory'
|
||||
copyright = '2024, LlamaFactory Team'
|
||||
author = 'LlamaFactory Team'
|
||||
project = "LlamaFactory"
|
||||
copyright = "2024, LlamaFactory Team"
|
||||
author = "LlamaFactory Team"
|
||||
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx.ext.viewcode',
|
||||
'sphinx.ext.napoleon',
|
||||
'myst_parser',
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx.ext.napoleon",
|
||||
"myst_parser",
|
||||
]
|
||||
|
||||
templates_path = ['_templates']
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
html_theme = "sphinx_rtd_theme"
|
||||
|
||||
html_static_path = ['_static']
|
||||
html_static_path = ["_static"]
|
||||
|
||||
html_js_files = [
|
||||
'js/switcher.js',
|
||||
"js/switcher.js",
|
||||
]
|
||||
|
||||
html_css_files = [
|
||||
'css/lang-switcher.css',
|
||||
"css/lang-switcher.css",
|
||||
]
|
||||
|
||||
myst_enable_extensions = [
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add parent dir to path to allow importing conf.py
|
||||
sys.path.insert(0, os.path.abspath('..'))
|
||||
|
||||
from conf import *
|
||||
# Add parent dir to path to allow importing conf.py
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
from conf import * # noqa: F403
|
||||
|
||||
|
||||
# Language settings
|
||||
language = 'en'
|
||||
html_search_language = 'en'
|
||||
language = "en"
|
||||
html_search_language = "en"
|
||||
|
||||
# Static files
|
||||
# Point to the root _static directory
|
||||
html_static_path = ['../_static']
|
||||
html_static_path = ["../_static"]
|
||||
|
||||
# Add custom JS for language switcher
|
||||
html_js_files = [
|
||||
'js/switcher.js',
|
||||
"js/switcher.js",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,22 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add parent dir to path to allow importing conf.py
|
||||
sys.path.insert(0, os.path.abspath('..'))
|
||||
|
||||
from conf import *
|
||||
# Add parent dir to path to allow importing conf.py
|
||||
sys.path.insert(0, os.path.abspath(".."))
|
||||
|
||||
from conf import * # noqa: F403
|
||||
|
||||
|
||||
# Language settings
|
||||
language = 'zh_CN'
|
||||
html_search_language = 'zh'
|
||||
language = "zh_CN"
|
||||
html_search_language = "zh"
|
||||
|
||||
# Static files
|
||||
# Point to the root _static directory
|
||||
html_static_path = ['../_static']
|
||||
html_static_path = ["../_static"]
|
||||
|
||||
# Add custom JS for language switcher
|
||||
html_js_files = [
|
||||
'js/switcher.js',
|
||||
"js/switcher.js",
|
||||
]
|
||||
|
||||
24
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
24
examples/v1/train_full/train_full_deepspeed.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
dist_config:
|
||||
name: deepspeed
|
||||
config_file: examples/deepspeed/ds_z3_config.json
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/Qwen3-0.6B-deepspeed
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: true
|
||||
max_steps: 10
|
||||
@@ -14,16 +14,12 @@ dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null # /mnt/f/pretrain_models/Qwen3-0.6B-dcp
|
||||
|
||||
init_config:
|
||||
name: init_on_meta
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_fsdp2
|
||||
micro_batch_size: 1
|
||||
global_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
|
||||
43
examples/v1/train_qlora/quantization.yaml
Normal file
43
examples/v1/train_qlora/quantization.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
model: Qwen/Qwen3-0.6B
|
||||
trust_remote_code: true
|
||||
model_class: llm
|
||||
|
||||
template: qwen3_nothink
|
||||
|
||||
# PEFT Configuration
|
||||
peft_config:
|
||||
name: lora
|
||||
r: 16
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
target_modules: all
|
||||
|
||||
# Kernel Config
|
||||
kernel_config:
|
||||
name: auto
|
||||
include_kernels: auto
|
||||
|
||||
# FSDP Config
|
||||
dist_config:
|
||||
name: fsdp2
|
||||
dcp_path: null
|
||||
|
||||
# Quantization Config
|
||||
quant_config:
|
||||
name: bnb # choice: auto/bnb if auto is selected, the quantization method will be automatically selected based on the model and environment.
|
||||
quantization_bit: 4 # choice: 8/4(bnb)
|
||||
|
||||
### data
|
||||
train_dataset: data/v1_sft_demo.yaml
|
||||
|
||||
### training
|
||||
output_dir: outputs/test_quantization
|
||||
micro_batch_size: 1
|
||||
cutoff_len: 2048
|
||||
learning_rate: 1.0e-4
|
||||
bf16: false
|
||||
max_steps: 10
|
||||
|
||||
### sample
|
||||
sample_backend: hf
|
||||
max_new_tokens: 128
|
||||
@@ -40,7 +40,7 @@ dependencies = [
|
||||
"torch>=2.4.0",
|
||||
"torchvision>=0.19.0",
|
||||
"torchaudio>=2.4.0",
|
||||
"transformers>=4.51.0,<=5.0.0,!=4.52.0,!=4.57.0",
|
||||
"transformers>=4.51.0,<=5.2.0,!=4.52.0,!=4.57.0",
|
||||
"datasets>=2.16.0,<=4.0.0",
|
||||
"accelerate>=1.3.0,<=1.11.0",
|
||||
"peft>=0.18.0,<=0.18.1",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
torch==2.7.1
|
||||
torch-npu==2.7.1
|
||||
torch-npu==2.7.1.post2
|
||||
torchvision==0.22.1
|
||||
torchaudio==2.7.1
|
||||
|
||||
@@ -71,6 +71,7 @@ def convert(
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
expert_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: int | None = None,
|
||||
moe_grouped_gemm: bool | None = None,
|
||||
):
|
||||
"""Convert checkpoint between MCA and HuggingFace formats.
|
||||
|
||||
@@ -84,6 +85,10 @@ def convert(
|
||||
pipeline_model_parallel_size: Pipeline model parallel size
|
||||
expert_model_parallel_size: Expert model parallel size
|
||||
virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size
|
||||
moe_grouped_gemm: Use grouped gemm for MoE experts. When enabled, expert
|
||||
weights are stored in a flattened format (linear_fc1.weight0, weight1, ...)
|
||||
rather than per-expert format (local_experts.0.linear_fc1.weight, ...).
|
||||
Must match the format used when saving the checkpoint.
|
||||
"""
|
||||
if bf16 and fp16:
|
||||
raise ValueError("bf16 and fp16 cannot be both True.")
|
||||
@@ -97,8 +102,9 @@ def convert(
|
||||
pipeline_model_parallel_size=pipeline_model_parallel_size,
|
||||
expert_model_parallel_size=expert_model_parallel_size,
|
||||
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
|
||||
moe_grouped_gemm=moe_grouped_gemm,
|
||||
transformer_impl="transformer_engine", # hard code here since we default using te for training
|
||||
)
|
||||
|
||||
convert_checkpoint_to_mca(
|
||||
checkpoint_path,
|
||||
output_path,
|
||||
|
||||
@@ -154,25 +154,24 @@ def vllm_infer(
|
||||
batch = train_dataset[i : min(i + batch_size, len(train_dataset))]
|
||||
|
||||
for j in range(len(batch["input_ids"])):
|
||||
multi_modal_data = {}
|
||||
video_metadata_kwargs = None
|
||||
|
||||
if batch["images"][j] is not None:
|
||||
image = batch["images"][j]
|
||||
multi_modal_data = {
|
||||
"image": template_obj.mm_plugin._regularize_images(
|
||||
multi_modal_data["image"] = template_obj.mm_plugin._regularize_images(
|
||||
image, image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
|
||||
)["images"]
|
||||
}
|
||||
elif batch["videos"][j] is not None:
|
||||
video_metadata, video_metadata_kwargs = None, None
|
||||
|
||||
if batch["videos"][j] is not None:
|
||||
video = batch["videos"][j]
|
||||
multi_modal_data = {
|
||||
"video": template_obj.mm_plugin._regularize_videos(
|
||||
multi_modal_data["video"] = template_obj.mm_plugin._regularize_videos(
|
||||
video,
|
||||
image_max_pixels=image_max_pixels,
|
||||
image_min_pixels=image_min_pixels,
|
||||
video_fps=video_fps,
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
if need_video_kwargs:
|
||||
container = av.open(video[0], "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
@@ -192,18 +191,17 @@ def vllm_infer(
|
||||
video_backend="opencv",
|
||||
)
|
||||
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||
elif batch["audios"][j] is not None:
|
||||
|
||||
if batch["audios"][j] is not None:
|
||||
audio = batch["audios"][j]
|
||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||
audio,
|
||||
sampling_rate=16000,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
||||
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data or None}
|
||||
if video_metadata_kwargs is not None:
|
||||
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||
|
||||
vllm_inputs.append(vllm_input_data)
|
||||
|
||||
@@ -180,35 +180,32 @@ class VllmEngine(BaseEngine):
|
||||
else self.generating_args["skip_special_tokens"],
|
||||
)
|
||||
|
||||
multi_modal_data = {}
|
||||
if images is not None: # add image features
|
||||
multi_modal_data = {
|
||||
"image": self.template.mm_plugin._regularize_images(
|
||||
multi_modal_data["image"] = self.template.mm_plugin._regularize_images(
|
||||
images,
|
||||
image_max_pixels=self.model_args.image_max_pixels,
|
||||
image_min_pixels=self.model_args.image_min_pixels,
|
||||
)["images"]
|
||||
}
|
||||
elif videos is not None:
|
||||
multi_modal_data = {
|
||||
"video": self.template.mm_plugin._regularize_videos(
|
||||
|
||||
if videos is not None:
|
||||
multi_modal_data["video"] = self.template.mm_plugin._regularize_videos(
|
||||
videos,
|
||||
image_max_pixels=self.model_args.video_max_pixels,
|
||||
image_min_pixels=self.model_args.video_min_pixels,
|
||||
video_fps=self.model_args.video_fps,
|
||||
video_maxlen=self.model_args.video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
elif audios is not None:
|
||||
|
||||
if audios is not None:
|
||||
audio_data = self.template.mm_plugin._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=self.model_args.audio_sampling_rate,
|
||||
)
|
||||
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
|
||||
else:
|
||||
multi_modal_data = None
|
||||
multi_modal_data["audio"] = zip(audio_data["audios"], audio_data["sampling_rates"])
|
||||
|
||||
result_generator = self.model.generate(
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
|
||||
{"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data or None},
|
||||
sampling_params=sampling_params,
|
||||
request_id=request_id,
|
||||
lora_request=self.lora_request,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
@@ -189,6 +190,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"video_grid_thw": mm_inputs.get("video_grid_thw"),
|
||||
"attention_mask": (features["attention_mask"] >= 1).float(),
|
||||
}
|
||||
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
|
||||
image_token_id = getattr(self.model.config, "image_token_id", None)
|
||||
video_token_id = getattr(self.model.config, "video_token_id", None)
|
||||
if image_token_id is not None or video_token_id is not None:
|
||||
mm_token_type_ids = torch.zeros_like(features["input_ids"])
|
||||
if image_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
|
||||
if video_token_id is not None:
|
||||
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
|
||||
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
|
||||
if "second_per_grid_ts" in mm_inputs: # for qwen2vl
|
||||
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
|
||||
elif "video_second_per_grid" in mm_inputs: # for qwen2.5 omni
|
||||
@@ -219,6 +230,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"qwen2_5_vl",
|
||||
"qwen2_5_omni_thinker",
|
||||
"qwen3_omni_moe_thinker",
|
||||
"qwen3_5",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
]
|
||||
|
||||
@@ -196,7 +196,7 @@ def read_cloud_json(cloud_path: str) -> list[Any]:
|
||||
|
||||
# filter out non-JSON files
|
||||
files = [x["Key"] for x in fs.listdir(cloud_path)] if fs.isdir(cloud_path) else [cloud_path]
|
||||
files = filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files)
|
||||
files = list(filter(lambda file: file.endswith(".json") or file.endswith(".jsonl"), files))
|
||||
if not files:
|
||||
raise ValueError(f"No JSON/JSONL files found in the specified path: {cloud_path}.")
|
||||
|
||||
|
||||
@@ -161,7 +161,9 @@ class MMPluginMixin:
|
||||
video_processor: BaseImageProcessor = getattr(
|
||||
processor, "video_processor", getattr(processor, "image_processor", None)
|
||||
)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError(
|
||||
"This model does not support image input. Please check whether the correct `template` is used."
|
||||
@@ -390,7 +392,9 @@ class MMPluginMixin:
|
||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||
|
||||
if len(audios) != 0:
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
@@ -1876,7 +1880,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
) -> dict[str, "torch.Tensor"]:
|
||||
image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
|
||||
video_processor: BaseVideoProcessor = getattr(processor, "video_processor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
|
||||
feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None) or getattr(
|
||||
processor, "audio_processor", None
|
||||
)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
|
||||
@@ -1113,7 +1113,7 @@ register_template(
|
||||
register_template(
|
||||
name="gpt_oss",
|
||||
format_user=StringFormatter(slots=["<|start|>user<|message|>{{content}}<|end|><|start|>assistant"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|end|>"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|start|>system<|message|>{{content}}<|end|>"]),
|
||||
default_system="You are ChatGPT, a large language model trained by OpenAI.",
|
||||
thought_words=("<|channel|>analysis<|message|>", "<|end|><|start|>assistant<|channel|>final<|message|>"),
|
||||
@@ -2029,6 +2029,39 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="qwen3_5",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
template_class=ReasoningTemplate,
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="qwen3_5_nothink",
|
||||
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
|
||||
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen3_5"),
|
||||
format_observation=StringFormatter(
|
||||
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="qwen3_5"),
|
||||
stop_words=["<|im_end|>"],
|
||||
replace_eos=True,
|
||||
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="sailor",
|
||||
format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
|
||||
@@ -2218,3 +2251,24 @@ register_template(
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
||||
default_system="You are Zephyr, a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
# copied from glm4_7 template
|
||||
register_template(
|
||||
name="aeva",
|
||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
||||
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
|
||||
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
||||
format_tools=ToolFormatter(tool_format="glm4_moe"),
|
||||
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
||||
default_system=(
|
||||
"You are an AI assistant named Aeva created by Zongzhi Lou. "
|
||||
"Your answer should be friendly, unbiased, faithful, informative and detailed."
|
||||
),
|
||||
stop_words=["<|user|>", "<|observation|>"],
|
||||
thought_words=("<think>", "</think>"),
|
||||
efficient_eos=True,
|
||||
template_class=Glm47ReasoningTemplate,
|
||||
)
|
||||
|
||||
@@ -85,6 +85,21 @@ QWEN_TOOL_PROMPT = (
|
||||
""""arguments": <args-json-object>}}\n</tool_call>"""
|
||||
)
|
||||
|
||||
QWEN35_TOOL_PROMPT = (
|
||||
"\n\n# Tools\n\nYou have access to the following functions:\n\n<tools>{tool_text}"
|
||||
"\n</tools>\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n"
|
||||
"<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n"
|
||||
"<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n"
|
||||
"</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n"
|
||||
"- Function calls MUST follow the specified format: "
|
||||
"an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n"
|
||||
"- Required parameters MUST be specified\n"
|
||||
"- You may provide optional reasoning for your function call in natural language "
|
||||
"BEFORE the function call, but NOT after\n"
|
||||
"- If there is no function call available, answer the question like normal with your current knowledge "
|
||||
"and do not tell the user about function calls\n</IMPORTANT>"
|
||||
)
|
||||
|
||||
SEED_TOOL_PROMPT = (
|
||||
"system\nYou are Doubao, a helpful AI assistant. You may call one or more functions to assist with the user query."
|
||||
"Tool List:\nYou are authorized to use the following tools (described in JSON Schema format). Before performing "
|
||||
@@ -453,6 +468,57 @@ class QwenToolUtils(ToolUtils):
|
||||
return results
|
||||
|
||||
|
||||
class Qwen35ToolUtils(ToolUtils):
|
||||
r"""Qwen 3.5 tool using template."""
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
for tool in tools:
|
||||
tool = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||
tool_text += "\n" + json.dumps(tool, ensure_ascii=False)
|
||||
|
||||
return QWEN35_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||
function_texts = []
|
||||
for func in functions:
|
||||
name, arguments = func.name, json.loads(func.arguments)
|
||||
prompt = f"<tool_call>\n<function={name}>"
|
||||
for key, value in arguments.items():
|
||||
prompt += f"\n<parameter={key}>"
|
||||
if not isinstance(value, str):
|
||||
value = json.dumps(value, ensure_ascii=False)
|
||||
prompt += f"\n{value}\n</parameter>"
|
||||
prompt += "\n</function>\n</tool_call>"
|
||||
function_texts.append(prompt)
|
||||
|
||||
return "\n".join(function_texts)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||
results = []
|
||||
regex = re.compile(r"<tool_call>\s*<function=\s*([^\s<>]+)\s*(.*?)\s*</function>\s*</tool_call>", re.DOTALL)
|
||||
for func_name, params_block in re.findall(regex, content):
|
||||
args_dict = {}
|
||||
param_pattern = re.compile(r"<parameter=(.*?)>(.*?)</parameter>", re.DOTALL)
|
||||
for key, raw_value in re.findall(param_pattern, params_block.strip()):
|
||||
value = raw_value.strip()
|
||||
try:
|
||||
parsed_value = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
parsed_value = raw_value.strip()
|
||||
args_dict[key] = parsed_value
|
||||
|
||||
results.append(FunctionCall(func_name.strip(), json.dumps(args_dict, ensure_ascii=False)))
|
||||
|
||||
return results if results else content
|
||||
|
||||
|
||||
class GLM4MOEToolUtils(QwenToolUtils):
|
||||
r"""GLM-4-MOE tool using template."""
|
||||
|
||||
@@ -662,6 +728,7 @@ TOOLS = {
|
||||
"minimax2": MiniMaxM2ToolUtils(),
|
||||
"mistral": MistralToolUtils(),
|
||||
"qwen": QwenToolUtils(),
|
||||
"qwen3_5": Qwen35ToolUtils(),
|
||||
"glm4_moe": GLM4MOEToolUtils(),
|
||||
"seed_oss": SeedToolUtils(),
|
||||
"ling": LingToolUtils(),
|
||||
|
||||
@@ -65,9 +65,12 @@ MCA_SUPPORTED_MODELS = {
|
||||
"qwen2_vl",
|
||||
"qwen2_5_vl",
|
||||
"qwen3_vl",
|
||||
"qwen3_vl_moe",
|
||||
"qwen3",
|
||||
"qwen3_moe",
|
||||
"qwen3_next",
|
||||
"qwen3_5",
|
||||
"qwen3_5_moe",
|
||||
}
|
||||
|
||||
METHODS = ["full", "freeze", "lora", "oft"]
|
||||
@@ -2809,6 +2812,66 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen3.5-0.8B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B-Base",
|
||||
},
|
||||
"Qwen3.5-2B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B-Base",
|
||||
},
|
||||
"Qwen3.5-4B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B-Base",
|
||||
},
|
||||
"Qwen3.5-9B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B-Base",
|
||||
},
|
||||
"Qwen3.5-35B-A3B-Base": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B-Base",
|
||||
},
|
||||
"Qwen3.5-0.8B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-0.8B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-0.8B",
|
||||
},
|
||||
"Qwen3.5-2B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-2B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-2B",
|
||||
},
|
||||
"Qwen3.5-4B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-4B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-4B",
|
||||
},
|
||||
"Qwen3.5-9B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-9B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-9B",
|
||||
},
|
||||
"Qwen3.5-27B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-27B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-27B",
|
||||
},
|
||||
"Qwen3.5-35B-A3B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-35B-A3B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-35B-A3B",
|
||||
},
|
||||
"Qwen3.5-122B-A10B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-122B-A10B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-122B-A10B",
|
||||
},
|
||||
"Qwen3.5-397B-A17B-Thinking": {
|
||||
DownloadSource.DEFAULT: "Qwen/Qwen3.5-397B-A17B",
|
||||
DownloadSource.MODELSCOPE: "Qwen/Qwen3.5-397B-A17B",
|
||||
},
|
||||
},
|
||||
template="qwen3_5",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Qwen2-Audio-7B": {
|
||||
@@ -3450,3 +3513,35 @@ register_model_group(
|
||||
},
|
||||
template="zephyr",
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Aeva-Flash-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Flash",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Flash",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Flash",
|
||||
},
|
||||
"Aeva-Air-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Air",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Air",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Air",
|
||||
},
|
||||
"Aeva-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva",
|
||||
},
|
||||
"Aeva-Pro-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Pro",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Pro",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Pro",
|
||||
},
|
||||
"Aeva-Max-Chat": {
|
||||
DownloadSource.DEFAULT: "louzongzhi/Aeva-Max",
|
||||
DownloadSource.MODELSCOPE: "louzongktsi/Aeva-Max",
|
||||
DownloadSource.OPENMIND: "louzongzhi/Aeva-Max",
|
||||
},
|
||||
},
|
||||
template="aeva",
|
||||
)
|
||||
|
||||
@@ -94,7 +94,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.51.0,<=5.0.0")
|
||||
check_version("transformers>=4.51.0,<=5.2.0")
|
||||
check_version("datasets>=2.16.0,<=4.0.0")
|
||||
check_version("accelerate>=1.3.0,<=1.11.0")
|
||||
check_version("peft>=0.18.0,<=0.18.1")
|
||||
|
||||
@@ -100,6 +100,52 @@ def _parse_args(
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
def _verify_trackio_args(training_args: "TrainingArguments") -> None:
|
||||
"""Validates Trackio-specific arguments.
|
||||
|
||||
Args:
|
||||
training_args: TrainingArguments instance (not a dictionary)
|
||||
"""
|
||||
report_to = training_args.report_to
|
||||
if not report_to:
|
||||
return
|
||||
|
||||
if isinstance(report_to, str):
|
||||
report_to = [report_to]
|
||||
|
||||
if "trackio" not in report_to:
|
||||
return
|
||||
|
||||
# --- Enforce project (required by Trackio) ---
|
||||
if not training_args.project:
|
||||
raise ValueError("`--project` must be specified when using Trackio.")
|
||||
|
||||
# --- Validate trackio_space_id format ---
|
||||
space_id = training_args.trackio_space_id
|
||||
if space_id:
|
||||
if space_id != "trackio" and "/" not in space_id:
|
||||
logger.warning(
|
||||
f"trackio_space_id '{space_id}' should typically be in format "
|
||||
"'org/space' for Hugging Face Spaces deployment."
|
||||
)
|
||||
|
||||
# --- Inform about default project usage ---
|
||||
if training_args.project == "huggingface":
|
||||
logger.info(
|
||||
"Using default project name 'huggingface'. "
|
||||
"Consider setting a custom project name with --project "
|
||||
"for better organization."
|
||||
)
|
||||
|
||||
# --- Validate hub repo privacy flag ---
|
||||
if training_args.hub_private_repo:
|
||||
logger.info("Repository will be created as private on Hugging Face Hub.")
|
||||
|
||||
# --- Recommend run_name for experiment clarity ---
|
||||
if not training_args.run_name:
|
||||
logger.warning("Consider setting --run_name for better experiment tracking clarity.")
|
||||
|
||||
|
||||
def _set_transformers_logging() -> None:
|
||||
if os.getenv("LLAMAFACTORY_VERBOSITY", "INFO") in ["DEBUG", "INFO"]:
|
||||
transformers.utils.logging.set_verbosity_info()
|
||||
@@ -278,8 +324,10 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
if finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support lora reward model.")
|
||||
|
||||
if training_args.report_to and training_args.report_to[0] not in ["wandb", "tensorboard"]:
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
if training_args.report_to and any(
|
||||
logger not in ("wandb", "tensorboard", "trackio", "none") for logger in training_args.report_to
|
||||
):
|
||||
raise ValueError("PPO only accepts wandb, tensorboard, or trackio logger.")
|
||||
|
||||
if not model_args.use_kt and training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||
@@ -352,6 +400,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
_set_env_vars()
|
||||
_verify_model_args(model_args, data_args, finetuning_args)
|
||||
_check_extra_dependencies(model_args, finetuning_args, training_args)
|
||||
_verify_trackio_args(training_args)
|
||||
|
||||
if not finetuning_args.use_mca and training_args.fp8_enable_fsdp_float8_all_gather and not training_args.fp8:
|
||||
logger.warning_rank0("fp8_enable_fsdp_float8_all_gather requires fp8=True. Setting fp8=True.")
|
||||
@@ -421,7 +470,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
|
||||
@@ -142,6 +142,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
|
||||
|
||||
_set_z3_leaf_modules(model, [Qwen3OmniMoeThinkerTextSparseMoeBlock])
|
||||
|
||||
if model_type == "qwen3_next":
|
||||
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
|
||||
|
||||
_set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])
|
||||
|
||||
|
||||
def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
if not is_trainable or not model_args.moe_aux_loss_coef:
|
||||
|
||||
@@ -390,7 +390,25 @@ _register_composite_model(
|
||||
"visual.deepstack_merger_list",
|
||||
"audio_tower",
|
||||
],
|
||||
language_model_keys=["model", "lm_head"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5",
|
||||
projector_key="model.visual.merger",
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="qwen3_5_moe",
|
||||
projector_key="model.visual.merger",
|
||||
vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
|
||||
language_model_keys=["language_model", "lm_head"],
|
||||
lora_conflict_keys=["patch_embed"],
|
||||
)
|
||||
|
||||
|
||||
@@ -228,7 +228,7 @@ class LogCallback(TrainerCallback):
|
||||
if (
|
||||
args.should_save
|
||||
and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
and args.overwrite_output_dir
|
||||
and getattr(args, "overwrite_output_dir", False)
|
||||
):
|
||||
logger.warning_rank0_once("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
@@ -371,6 +371,18 @@ class ReporterCallback(TrainerCallback):
|
||||
}
|
||||
)
|
||||
|
||||
if "trackio" in args.report_to:
|
||||
import trackio
|
||||
|
||||
trackio.config.update(
|
||||
{
|
||||
"model_args": self.model_args.to_dict(),
|
||||
"data_args": self.data_args.to_dict(),
|
||||
"finetuning_args": self.finetuning_args.to_dict(),
|
||||
"generating_args": self.generating_args.to_dict(),
|
||||
}
|
||||
)
|
||||
|
||||
if self.finetuning_args.use_swanlab:
|
||||
import swanlab # type: ignore
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
@@ -77,12 +79,43 @@ def _data_collator_wrapper(data_collator: Any):
|
||||
|
||||
def _check_model_support(model_args: "ModelArguments"):
|
||||
from transformers import AutoConfig as HfAutoConfig
|
||||
|
||||
if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt
|
||||
mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
|
||||
model_type = mca_config.get("hf_model_type", None)
|
||||
else:
|
||||
config = HfAutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
if config.model_type not in MCA_SUPPORTED_MODELS:
|
||||
raise ValueError(f"Model {config.model_type} is not supported by MCA.")
|
||||
model_type = config.model_type
|
||||
|
||||
if model_type not in MCA_SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model {model_type} is not supported by mcore_adapter."
|
||||
"You can try to upgrade mcore_adapter to the latest version for more supported models."
|
||||
)
|
||||
|
||||
|
||||
def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
|
||||
"""Freeze model parameters for qwen_vl series models based on finetuning arguments."""
|
||||
if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
|
||||
return
|
||||
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen3_vl", "qwen3_vl_moe"]:
|
||||
params_to_freeze.extend(["vision_model.pos_embed"])
|
||||
|
||||
if finetuning_args.freeze_multi_modal_projector:
|
||||
params_to_freeze.extend(["multi_modal_projector"])
|
||||
|
||||
if finetuning_args.freeze_language_model:
|
||||
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
||||
|
||||
if params_to_freeze:
|
||||
for name, p in model.named_parameters():
|
||||
if any(name.startswith(k) for k in params_to_freeze):
|
||||
p.requires_grad_(False)
|
||||
|
||||
|
||||
def run_pt(
|
||||
@@ -161,22 +194,8 @@ def run_sft(
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
# optional freezing for qwen2_vl, qwen2_5_vl
|
||||
if getattr(model.config, "hf_model_type", None) in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl"]:
|
||||
params_to_freeze = []
|
||||
if finetuning_args.freeze_vision_tower:
|
||||
params_to_freeze.extend(["vision_model.blocks", "vision_model.patch_embed"])
|
||||
|
||||
if finetuning_args.freeze_multi_modal_projector:
|
||||
params_to_freeze.extend(["multi_modal_projector"])
|
||||
|
||||
if finetuning_args.freeze_language_model:
|
||||
params_to_freeze.extend(["embedding", "decoder", "output_layer"])
|
||||
|
||||
if params_to_freeze:
|
||||
for name, p in model.named_parameters():
|
||||
if any(name.startswith(k) for k in params_to_freeze):
|
||||
p.requires_grad_(False)
|
||||
# optional freezing for qwen_vl series
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
|
||||
pad_to_max = training_args.expert_model_parallel_size is not None and training_args.expert_model_parallel_size > 1
|
||||
data_collator = SFTDataCollatorWith4DAttentionMask(
|
||||
@@ -229,6 +248,8 @@ def run_dpo(
|
||||
_check_model_support(model_args)
|
||||
model = AutoModel.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
|
||||
_freeze_model_parameters(model, finetuning_args)
|
||||
|
||||
if finetuning_args.use_ref_model:
|
||||
ref_config = AutoConfig.from_pretrained(model_args.model_name_or_path, training_args)
|
||||
ref_model = AutoModel.from_config(ref_config)
|
||||
|
||||
@@ -215,7 +215,13 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
if len(pad_len): # move pad token to last
|
||||
preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
|
||||
|
||||
decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
|
||||
input_ids_column = dataset["input_ids"]
|
||||
try:
|
||||
input_ids_list = input_ids_column.to_pylist()
|
||||
except AttributeError:
|
||||
input_ids_list = list(input_ids_column)
|
||||
|
||||
decoded_inputs = self.processing_class.batch_decode(input_ids_list, skip_special_tokens=False)
|
||||
decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens)
|
||||
decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ if is_apollo_available():
|
||||
|
||||
if is_ray_available():
|
||||
import ray
|
||||
from ray.util.state import list_nodes
|
||||
from ray.util.placement_group import PlacementGroup, placement_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
@@ -941,7 +942,7 @@ def get_ray_remote_config_for_worker(
|
||||
|
||||
def get_ray_head_node_ip() -> str:
|
||||
r"""Get the IP address of the Ray head node."""
|
||||
head_ip = next(node["NodeManagerAddress"] for node in ray.nodes() if node.get("IsHead", False))
|
||||
head_ip = next(node["node_ip"] for node in list_nodes() if node.get("is_head_node", False))
|
||||
return head_ip
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras import logging
|
||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
|
||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available
|
||||
from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
|
||||
from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
|
||||
from ..model import load_model, load_tokenizer
|
||||
from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
|
||||
@@ -160,17 +160,28 @@ def export_model(args: Optional[dict[str, Any]] = None) -> None:
|
||||
model = model.to(output_dtype)
|
||||
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
|
||||
|
||||
model.save_pretrained(
|
||||
save_directory=model_args.export_dir,
|
||||
max_shard_size=f"{model_args.export_size}GB",
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
)
|
||||
# Prepare save arguments (safe_serialization removed in transformers v5.0.0)
|
||||
save_kwargs = {
|
||||
"save_directory": model_args.export_dir,
|
||||
"max_shard_size": f"{model_args.export_size}GB",
|
||||
}
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
save_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||
|
||||
model.save_pretrained(**save_kwargs)
|
||||
|
||||
if model_args.export_hub_model_id is not None:
|
||||
# Prepare push arguments (safe_serialization removed in transformers v5.0.0)
|
||||
push_kwargs = {
|
||||
"max_shard_size": f"{model_args.export_size}GB",
|
||||
}
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
push_kwargs["safe_serialization"] = not model_args.export_legacy_format
|
||||
|
||||
model.push_to_hub(
|
||||
model_args.export_hub_model_id,
|
||||
token=model_args.hf_hub_token,
|
||||
max_shard_size=f"{model_args.export_size}GB",
|
||||
safe_serialization=(not model_args.export_legacy_format),
|
||||
**push_kwargs,
|
||||
)
|
||||
|
||||
if finetuning_args.stage == "rm":
|
||||
|
||||
@@ -21,6 +21,7 @@ from omegaconf import OmegaConf
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ..utils.env import is_env_enabled
|
||||
from ..utils.helper import set_seed
|
||||
from .data_args import DataArguments
|
||||
from .model_args import ModelArguments
|
||||
from .sample_args import SampleArguments
|
||||
@@ -56,6 +57,14 @@ def get_args(args: InputArgument = None) -> tuple[ModelArguments, DataArguments,
|
||||
print(f"Got unknown args, potentially deprecated arguments: {unknown_args}")
|
||||
raise ValueError(f"Some specified arguments are not used by the HfArgumentParser: {unknown_args}")
|
||||
|
||||
# Seed as early as possible after argument parsing so all downstream
|
||||
# components (dist init, dataloader, model init in run_* entrypoints) share the same RNG state.
|
||||
for arg in parsed_args:
|
||||
seed = getattr(arg, "seed", None)
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
break
|
||||
|
||||
return tuple(parsed_args)
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class TrainingArguments:
|
||||
metadata={"help": "Number of workers for batching."},
|
||||
)
|
||||
enable_activation_checkpointing: bool = field(
|
||||
default=True,
|
||||
default=False,
|
||||
metadata={"help": "Enable activation checkpointing for training."},
|
||||
)
|
||||
dist_config: PluginConfig | None = field(
|
||||
@@ -81,6 +81,10 @@ class TrainingArguments:
|
||||
default=None,
|
||||
metadata={"help": "Learning rate scheduler configuration for training."},
|
||||
)
|
||||
seed: int = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed that will be set at the beginning of training."},
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.dist_config = get_plugin_config(self.dist_config)
|
||||
|
||||
@@ -76,18 +76,27 @@ class BaseTrainer:
|
||||
if self.args.enable_activation_checkpointing:
|
||||
self.model.gradient_checkpointing_enable({"use_reentrant": False})
|
||||
|
||||
if self.args.dist_config is not None:
|
||||
shard_need_optimizer = self.args.dist_config.name == "deepspeed"
|
||||
else:
|
||||
shard_need_optimizer = False
|
||||
self._deepspeed_engine = None
|
||||
dist_name = self.args.dist_config.name if self.args.dist_config is not None else None
|
||||
|
||||
if shard_need_optimizer:
|
||||
if dist_name == "deepspeed":
|
||||
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||
|
||||
self._deepspeed_engine = DistributedPlugin("deepspeed")(
|
||||
self.model,
|
||||
self.args.dist_config,
|
||||
num_micro_batch=self.train_batch_generator.num_micro_batch,
|
||||
micro_batch_size=self.args.micro_batch_size,
|
||||
)
|
||||
self._init_optimizer()
|
||||
self._shard_model()
|
||||
self._init_lr_scheduler()
|
||||
self.model, self.optimizer, self.lr_scheduler = self._deepspeed_engine.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
else:
|
||||
# fsdp2 / DDP / no dist
|
||||
self._shard_model()
|
||||
self._init_optimizer()
|
||||
|
||||
self._init_lr_scheduler()
|
||||
|
||||
def _create_batch_generator(self) -> None:
|
||||
@@ -99,6 +108,7 @@ class BaseTrainer:
|
||||
cutoff_len=self.args.cutoff_len,
|
||||
batching_workers=self.args.batching_workers,
|
||||
batching_strategy=self.args.batching_strategy,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
def _shard_model(self) -> None:
|
||||
@@ -171,15 +181,25 @@ class BaseTrainer:
|
||||
step_loss = 0
|
||||
step_valid_tokens = compute_valid_tokens(micro_batches)
|
||||
step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
|
||||
for micro_batch in micro_batches:
|
||||
num_micro = len(micro_batches)
|
||||
for i, micro_batch in enumerate(micro_batches):
|
||||
loss = self.compute_loss(micro_batch)
|
||||
mini_step_valid_tokens = compute_valid_tokens([micro_batch])
|
||||
# fsdp uses mean reduction so we need to scale the loss by dp_size
|
||||
loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
|
||||
|
||||
if self._deepspeed_engine is not None:
|
||||
# deepspeed: set sync_gradients so engine.step() only fires on last micro-batch
|
||||
self._deepspeed_engine.accelerator.sync_gradients = i == num_micro - 1
|
||||
self._deepspeed_engine.backward(loss)
|
||||
else:
|
||||
loss.backward()
|
||||
step_loss += loss.item()
|
||||
|
||||
if self._deepspeed_engine is not None:
|
||||
# deepspeed: engine.step() already ran inside backward at the sync boundary
|
||||
grad_norm = self._deepspeed_engine.get_grad_norm()
|
||||
else:
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
|
||||
|
||||
# isfinite(): argument 'input' (position 1) must be Tensor, not float
|
||||
@@ -203,17 +223,14 @@ class BaseTrainer:
|
||||
|
||||
def save_model(self) -> None:
|
||||
"""Save the model."""
|
||||
if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
|
||||
from ..plugins.trainer_plugins.distributed.hub import DistributedPlugin
|
||||
|
||||
DistributedPlugin(self.args.dist_config.name).save_model(
|
||||
self.model, self.args.output_dir, self.renderer.processor
|
||||
)
|
||||
else:
|
||||
model_to_save = self.model.module if hasattr(self.model, "module") else self.model
|
||||
state_dict = None
|
||||
if self.args.dist_config is not None and self.args.dist_config.name == "fsdp2":
|
||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
|
||||
|
||||
options = StateDictOptions(full_state_dict=True, cpu_offload=True)
|
||||
state_dict = get_model_state_dict(self.model, options=options)
|
||||
|
||||
if DistributedInterface().get_rank() != 0:
|
||||
return
|
||||
|
||||
model_to_save.save_pretrained(self.args.output_dir, state_dict=state_dict)
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir)
|
||||
model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
|
||||
logger.info_rank0(f"Model saved to {self.args.output_dir}")
|
||||
|
||||
@@ -90,6 +90,26 @@ class ModelEngine:
|
||||
Transformers can choose the proper model init context.
|
||||
https://github.com/huggingface/transformers/blob/v5.0.0rc0/src/transformers/modeling_utils.py#L3538
|
||||
"""
|
||||
if self.args.init_config is not None:
|
||||
from ..plugins.model_plugins.initialization import InitPlugin
|
||||
|
||||
init_device = InitPlugin(self.args.init_config.name)()
|
||||
else:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
init_kwargs = {"device_map": init_device}
|
||||
|
||||
if self.args.quant_config is not None:
|
||||
from ..plugins.model_plugins.quantization import QuantizationPlugin
|
||||
|
||||
init_kwargs = QuantizationPlugin(self.args.quant_config.name)(
|
||||
init_kwargs=init_kwargs,
|
||||
config=self.model_config,
|
||||
tokenizer=self.processor,
|
||||
model_args=self.args,
|
||||
is_trainable=self.is_train,
|
||||
)
|
||||
|
||||
if self.args.model_class == ModelClass.LLM:
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||
|
||||
@@ -107,14 +127,8 @@ class ModelEngine:
|
||||
|
||||
AutoClass = AutoModel
|
||||
|
||||
if self.args.init_config is not None:
|
||||
from ..plugins.model_plugins.initialization import InitPlugin
|
||||
|
||||
init_device = InitPlugin(self.args.init_config.name)()
|
||||
else:
|
||||
init_device = DistributedInterface().current_device
|
||||
|
||||
if init_device.type == DeviceType.META:
|
||||
assert self.args.quant_config is None, "Quantization is not supported with meta device."
|
||||
with init_empty_weights():
|
||||
model = AutoClass.from_config(self.model_config)
|
||||
else:
|
||||
@@ -122,8 +136,8 @@ class ModelEngine:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=init_device,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
**init_kwargs,
|
||||
)
|
||||
|
||||
if self.args.peft_config is None:
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch.utils.data import default_collate
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
@@ -71,6 +72,7 @@ class BatchGenerator(Iterator):
|
||||
batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
|
||||
pin_memory: bool = True,
|
||||
drop_last: bool = True,
|
||||
seed: int = 42,
|
||||
) -> None:
|
||||
self.dataset = dataset
|
||||
self.renderer = renderer
|
||||
@@ -82,6 +84,7 @@ class BatchGenerator(Iterator):
|
||||
self.batching_strategy = batching_strategy
|
||||
self.pin_memory = pin_memory
|
||||
self.drop_last = drop_last
|
||||
self.seed = seed
|
||||
# TODO: support length and infinity
|
||||
dp_size = DistributedInterface().get_world_size(Dim.DP)
|
||||
|
||||
@@ -128,12 +131,15 @@ class BatchGenerator(Iterator):
|
||||
num_replicas=DistributedInterface().get_world_size(Dim.DP),
|
||||
rank=DistributedInterface().get_rank(Dim.DP),
|
||||
shuffle=True,
|
||||
seed=0,
|
||||
seed=self.seed,
|
||||
drop_last=self.drop_last,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Iterable dataset is not supported yet.")
|
||||
|
||||
generato_seed = torch.Generator()
|
||||
generato_seed.manual_seed(self.seed)
|
||||
|
||||
self._data_provider = StatefulDataLoader(
|
||||
self.dataset,
|
||||
batch_size=self.micro_batch_size * self.num_micro_batch,
|
||||
@@ -143,6 +149,7 @@ class BatchGenerator(Iterator):
|
||||
pin_memory=self.pin_memory,
|
||||
pin_memory_device=DistributedInterface().current_device.type,
|
||||
drop_last=self.drop_last,
|
||||
generator=generato_seed,
|
||||
)
|
||||
if self.batching_strategy == BatchingStrategy.NORMAL:
|
||||
self._length = len(self._data_provider)
|
||||
|
||||
@@ -150,6 +150,9 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
|
||||
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
|
||||
if model.device.type == "meta":
|
||||
raise ValueError("Currently lora stage does not support loading model by meta.")
|
||||
|
||||
adapter_name_or_path = config.get("adapter_name_or_path")
|
||||
|
||||
if adapter_name_or_path:
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright 2025 HuggingFace Inc., the KVCache.AI team, Approaching AI, and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
from ...accelerator.helper import get_current_device
|
||||
from ...config.model_args import ModelArguments
|
||||
from ...utils import logging
|
||||
from ...utils.packages import check_version
|
||||
from ...utils.plugin import BasePlugin
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedTokenizer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class QuantizationPlugin(BasePlugin):
|
||||
r"""Plugin for model quantization."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
init_kwargs: dict[str, Any] = None,
|
||||
config: "PretrainedConfig" = None,
|
||||
tokenizer: "PreTrainedTokenizer" = None,
|
||||
model_args: "ModelArguments" = None,
|
||||
is_trainable: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
return super().__call__(
|
||||
init_kwargs, config=config, tokenizer=tokenizer, model_args=model_args, is_trainable=is_trainable
|
||||
)
|
||||
|
||||
|
||||
@QuantizationPlugin("auto").register()
|
||||
def quantization_auto(
|
||||
init_kwargs: dict[str, Any],
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""Automatic quantization selection, only support bnb currently.
|
||||
|
||||
Args:
|
||||
init_kwargs (dict[str, Any]): The kwargs for model initialization.
|
||||
**kwargs: Keyword arguments containing the model.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: The updated kwargs for model initialization.
|
||||
"""
|
||||
model_args: ModelArguments = kwargs.get("model_args", None)
|
||||
quant_config = model_args.quant_config
|
||||
|
||||
quantization_bit = quant_config.get("quantization_bit", None)
|
||||
if quantization_bit is not None:
|
||||
logger.info_rank0(f"Loading {quantization_bit}-bit quantized model.")
|
||||
if quantization_bit in [8, 4]:
|
||||
return quantization_with_bnb(init_kwargs, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization bit: {quantization_bit} for auto quantization.")
|
||||
logger.warning_rank0("No quantization method applied.")
|
||||
return init_kwargs
|
||||
|
||||
|
||||
@QuantizationPlugin("bnb").register()
|
||||
def quantization_with_bnb(
|
||||
init_kwargs: dict[str, Any],
|
||||
model_args: "ModelArguments" = None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
r"""Quantization with BNB."""
|
||||
logger.info_rank0("Using Bitsandbytes quantization.")
|
||||
quantization_bit = model_args.quant_config.get("quantization_bit", None)
|
||||
if quantization_bit is None:
|
||||
logger.warning_rank0("quantization_bit is not specified, default to 8-bit quantization.")
|
||||
quantization_bit = 4
|
||||
assert quantization_bit in [8, 4], "Bitsandbytes only accepts 4-bit or 8-bit quantization."
|
||||
if quantization_bit == 8:
|
||||
check_version("bitsandbytes>=0.37.0", mandatory=True)
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif quantization_bit == 4:
|
||||
check_version("bitsandbytes>=0.39.0", mandatory=True)
|
||||
init_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=model_args.quant_config.get("compute_dtype", torch.float16),
|
||||
bnb_4bit_use_double_quant=model_args.quant_config.get("double_quantization", True),
|
||||
bnb_4bit_quant_type=model_args.quant_config.get("quantization_type", "nf4"),
|
||||
bnb_4bit_quant_storage=model_args.quant_config.get(
|
||||
"compute_dtype", torch.float16
|
||||
), # crucial for fsdp+qlora
|
||||
)
|
||||
else:
|
||||
raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.")
|
||||
|
||||
# TODO: improve deepspeed zero3 and fsdp detection.
|
||||
if kwargs.get("is_trainable", False):
|
||||
logger.info_rank0("Detected inference mode, setting device_map for bitsandbytes quantization.")
|
||||
init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference
|
||||
else:
|
||||
logger.info_rank0("Detected training mode, skip setting device_map for bitsandbytes quantization.")
|
||||
if model_args.quant_config.get("quantization_bit") != 4:
|
||||
raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.")
|
||||
|
||||
check_version("bitsandbytes>=0.43.0", mandatory=True)
|
||||
|
||||
logger.info_rank0(f"Quantizing model to {model_args.quant_config.get('quantization_bit')} bit with bitsandbytes.")
|
||||
return init_kwargs
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""DeepSpeed integration via accelerate's built-in capabilities.
|
||||
|
||||
Instead of manually calling deepspeed.initialize() and syncing config,
|
||||
this module leverages accelerate's Accelerator + DeepSpeedPlugin to handle
|
||||
initialization, backward, gradient accumulation, and model saving.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils.types import HFModel, Processor
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DeepSpeedEngine:
    """DeepSpeed integration using accelerate's built-in capabilities.

    This replaces the manual DeepSpeedConfigHelper / DeepSpeedEngine approach
    with accelerate's Accelerator + DeepSpeedPlugin, which handles:
    - Config syncing (auto values, batch size, lr, etc.)
    - deepspeed.initialize() call
    - Optimizer / LR scheduler wrapping
    - Backward + gradient accumulation boundary
    - ZeRO-3 parameter gathering for saving
    """

    def __init__(self, dist_config: dict[str, Any], num_micro_batch: int = 1, micro_batch_size: int = 1):
        """Build the accelerate-backed engine from a DeepSpeed JSON config.

        Args:
            dist_config: Plugin config; must contain "config_file" pointing at
                a DeepSpeed JSON config file.
            num_micro_batch: Passed to Accelerator as gradient_accumulation_steps.
            micro_batch_size: Fallback used to resolve an "auto" (or missing)
                train_micro_batch_size_per_gpu in the DeepSpeed config.
        """
        config_file = dist_config.get("config_file")
        if not config_file:
            raise ValueError("DeepSpeed config_file is required in dist_config")

        ds_plugin = DeepSpeedPlugin(hf_ds_config=config_file)

        self.accelerator = Accelerator(
            deepspeed_plugin=ds_plugin,
            gradient_accumulation_steps=num_micro_batch,
        )

        # Resolve "auto" for train_micro_batch_size_per_gpu so that
        # accelerate.prepare() does not require a DataLoader to infer it.
        ds_config = self.accelerator.state.deepspeed_plugin.deepspeed_config
        if ds_config.get("train_micro_batch_size_per_gpu") in (None, "auto"):
            ds_config["train_micro_batch_size_per_gpu"] = micro_batch_size

        logger.info_rank0(f"DeepSpeedEngine initialized with config: {config_file}")

    def shard_model(self, model: HFModel) -> "DeepSpeedEngine":
        """No-op shard — actual model wrapping happens in prepare().

        Returns self so the caller gets the engine instance via the hub interface.
        """
        return self

    def prepare(
        self,
        model: HFModel,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Optional[Any] = None,
    ) -> tuple[HFModel, torch.optim.Optimizer, Any]:
        """Prepare model, optimizer, and lr_scheduler using accelerate.

        Internally calls deepspeed.initialize() and wraps the returned objects.
        When no lr_scheduler is given, None is returned unchanged in its slot.
        """
        if lr_scheduler is not None:
            model, optimizer, lr_scheduler = self.accelerator.prepare(model, optimizer, lr_scheduler)
        else:
            model, optimizer = self.accelerator.prepare(model, optimizer)

        # Stash the accelerator on the model so the module-level save_model()
        # can reach it for ZeRO-aware state-dict gathering.
        model._accelerator = self.accelerator  # type: ignore[assignment]

        logger.info_rank0("Model, optimizer, and lr_scheduler prepared via accelerate")
        return model, optimizer, lr_scheduler

    def backward(self, loss: torch.Tensor) -> None:
        """Backward pass using accelerate.

        Delegates to DeepSpeedEngineWrapper.backward() which respects
        sync_gradients to control gradient accumulation boundaries.
        When sync_gradients=True: engine.backward(loss) + engine.step()
        When sync_gradients=False: engine.backward(loss) only
        """
        self.accelerator.backward(loss)

    def get_grad_norm(self) -> float:
        """Get the global gradient norm from the DeepSpeed engine.

        Returns 0.0 when no DeepSpeed engine has been created yet, or when
        the engine reports no norm (e.g. before the first optimizer step).
        """
        engine_wrapper = getattr(self.accelerator, "deepspeed_engine_wrapped", None)
        if engine_wrapper is not None:
            # get_global_grad_norm() may return None before any step; coerce to 0.0.
            return engine_wrapper.engine.get_global_grad_norm() or 0.0
        return 0.0
|
||||
|
||||
|
||||
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Write a checkpoint through accelerate's ZeRO-aware saving path.

    Requires prepare() to have attached the Accelerator as ``model._accelerator``.
    ``accelerator.get_state_dict()`` gathers ZeRO-3 shards transparently.
    """
    acc: Accelerator = model._accelerator  # type: ignore[union-attr]

    # Consolidate sharded parameters, then strip the engine wrapper for saving.
    full_state = acc.get_state_dict(model)
    base_model = acc.unwrap_model(model)

    if acc.is_main_process:
        base_model.save_pretrained(output_dir, state_dict=full_state, max_shard_size="4GB")
        processor.save_pretrained(output_dir, max_shard_size="4GB")

    acc.wait_for_everyone()
    logger.info_rank0(f"Model saved to {output_dir}")
|
||||
|
||||
@@ -12,29 +12,30 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from peft.tuners.lora import LoraLayer
|
||||
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
|
||||
from torch.distributed.fsdp import (
|
||||
CPUOffloadPolicy,
|
||||
MixedPrecisionPolicy,
|
||||
fully_shard,
|
||||
)
|
||||
from transformers import PreTrainedModel
|
||||
from peft.tuners.lora import LoraLayer
|
||||
|
||||
from ....accelerator.helper import get_current_accelerator
|
||||
from ....accelerator.interface import DistributedInterface
|
||||
from ....utils.logging import get_logger
|
||||
from ....utils.types import HFModel, Processor
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
|
||||
def get_transformer_layer_cls(model: HFModel) -> type[nn.Module] | None:
|
||||
no_split_modules = getattr(model, "_no_split_modules", None)
|
||||
if no_split_modules:
|
||||
if isinstance(no_split_modules, (list, tuple)):
|
||||
@@ -50,6 +51,20 @@ def get_transformer_layer_cls(model: PreTrainedModel) -> type[nn.Module] | None:
|
||||
return None
|
||||
|
||||
|
||||
def save_model(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Gather a full (unsharded) state dict and write the checkpoint on rank 0."""
    if DistributedInterface().get_rank() == 0:
        logger.info("Gathering state dict for saving...")

    # All ranks participate in the gather; cpu_offload keeps GPU memory flat.
    gather_options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    full_state = get_model_state_dict(model, options=gather_options)

    if DistributedInterface().get_rank() == 0:
        target = getattr(model, "module", model)
        target.save_pretrained(output_dir, state_dict=full_state, max_shard_size="4GB")
        processor.save_pretrained(output_dir, max_shard_size="4GB")
        logger.info(f"Model saved to {output_dir}")
|
||||
|
||||
|
||||
class FSDP2Engine:
|
||||
def __init__(self, dist_config: dict):
|
||||
self.dist_interface = DistributedInterface()
|
||||
@@ -95,11 +110,10 @@ class FSDP2Engine:
|
||||
cast_forward_inputs=True,
|
||||
)
|
||||
|
||||
|
||||
def is_lora_module_wrap(self, model) -> bool:
|
||||
return any(isinstance(module, LoraLayer) for module in model.modules())
|
||||
|
||||
def prepare_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
||||
def prepare_model(self, model: HFModel) -> HFModel:
|
||||
if self.fsdp_mesh is None:
|
||||
logger.warning("No FSDP Mesh available, skipping FSDP wrapping.")
|
||||
return model
|
||||
@@ -119,7 +133,6 @@ class FSDP2Engine:
|
||||
if self.is_lora_module_wrap(model):
|
||||
lora_modules = []
|
||||
for module in model.modules():
|
||||
|
||||
if len(list(module.children())) != 0:
|
||||
continue
|
||||
if any(param.requires_grad for param in module.parameters(recurse=False)):
|
||||
@@ -134,7 +147,7 @@ class FSDP2Engine:
|
||||
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||
)
|
||||
|
||||
logger.info(f"Applying FSDP wrap for LoRA layer separately.")
|
||||
logger.info("Applying FSDP wrap for LoRA layer separately.")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
should_wrap = False
|
||||
@@ -154,12 +167,11 @@ class FSDP2Engine:
|
||||
offload_policy=CPUOffloadPolicy(pin_memory=self.pin_memory) if self.offload_params else None,
|
||||
)
|
||||
|
||||
use_gradient_checkpointing = True # Could be configurable
|
||||
if use_gradient_checkpointing:
|
||||
# BaseTrainer is the single source of truth for gradient checkpointing.
|
||||
# FSDP2 only applies the input-grad compatibility hook when checkpointing is already enabled.
|
||||
if getattr(model, "is_gradient_checkpointing", False):
|
||||
if self.rank == 0:
|
||||
logger.info("Enabling gradient checkpointing (transformers native)...")
|
||||
|
||||
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||
logger.info("Gradient checkpointing is enabled. Applying FSDP2 input grad preparation.")
|
||||
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
@@ -179,8 +191,9 @@ class FSDP2Engine:
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def materialize_and_load(self, model: PreTrainedModel, hf_model_path: str, dcp_path: str = None):
|
||||
def materialize_and_load(self, model: HFModel, hf_model_path: str, dcp_path: str = None):
|
||||
if self.rank == 0:
|
||||
logger.info("Materializing sharded model params...")
|
||||
|
||||
@@ -200,15 +213,57 @@ class FSDP2Engine:
|
||||
|
||||
return model
|
||||
|
||||
def shard_model(self, model: PreTrainedModel) -> PreTrainedModel:
|
||||
def _save_non_persistent_buffers(self, model: HFModel) -> dict:
|
||||
"""Save non-persistent buffers, such as inv_freq."""
|
||||
saved = {}
|
||||
for mod_name, module in model.named_modules():
|
||||
for buf_name in module._non_persistent_buffers_set:
|
||||
fqn = f"{mod_name}.{buf_name}" if mod_name else buf_name
|
||||
buf = getattr(module, buf_name, None)
|
||||
if buf is not None:
|
||||
saved[fqn] = copy.deepcopy(buf)
|
||||
if self.rank == 0 and saved:
|
||||
logger.info(f"Saved {len(saved)} non-persistent buffers")
|
||||
return saved
|
||||
|
||||
def _restore_non_persistent_buffers(self, model: HFModel, saved_buffers: dict):
|
||||
"""Register saved non-persistent buffers to model."""
|
||||
if not saved_buffers:
|
||||
return
|
||||
device = get_current_accelerator()
|
||||
for fqn, buf in saved_buffers.items():
|
||||
buf = buf.to(device)
|
||||
if "." in fqn:
|
||||
parent_fqn, buf_name = fqn.rsplit(".", 1)
|
||||
parent_module = model.get_submodule(parent_fqn)
|
||||
else:
|
||||
buf_name = fqn
|
||||
parent_module = model
|
||||
parent_module.register_buffer(buf_name, buf, persistent=False)
|
||||
if self.rank == 0:
|
||||
logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
|
||||
|
||||
    def shard_model(self, model: HFModel) -> HFModel:
        """Apply FSDP2 sharding, handling meta-device (deferred-init) models.

        For a meta-device model the order matters: snapshot non-persistent
        buffers -> tie weights -> FSDP-wrap -> materialize + load weights ->
        re-tie -> restore buffers.
        """
        if model.device.type == "meta":
            # Non-persistent buffers (e.g. inv_freq) are not part of the
            # checkpoint, so snapshot them before materialization replaces
            # the meta tensors.
            non_persistent_buffers = self._save_non_persistent_buffers(model)

            if getattr(model.config, "tie_word_embeddings", None):
                model.tie_weights()

            model = self.prepare_model(model)
            model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)

            # fix tied broken for no-fsdp-wrap case
            if getattr(model.config, "tie_word_embeddings", None):
                model.tie_weights()

            self._restore_non_persistent_buffers(model, non_persistent_buffers)

        else:
            # Already-materialized model: wrap directly, nothing to restore.
            model = self.prepare_model(model)
        return model
|
||||
|
||||
def _load_from_dcp(self, model: PreTrainedModel, dcp_path: str):
|
||||
def _load_from_dcp(self, model: HFModel, dcp_path: str):
|
||||
import torch.distributed.checkpoint as dcp
|
||||
|
||||
try:
|
||||
@@ -227,7 +282,7 @@ class FSDP2Engine:
|
||||
logger.error(f"Failed to load from DCP: {e}")
|
||||
raise e
|
||||
|
||||
def _load_weights_from_hf_checkpoint(self, model, hf_model_path):
|
||||
def _load_weights_from_hf_checkpoint(self, model: HFModel, hf_model_path: str):
|
||||
import glob
|
||||
import json
|
||||
|
||||
|
||||
@@ -12,9 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ....config.arg_utils import PluginConfig
|
||||
from ....utils.plugin import BasePlugin
|
||||
from ....utils.types import HFModel
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ....utils.types import HFModel, Processor
|
||||
|
||||
|
||||
class DistributedPlugin(BasePlugin):
|
||||
@@ -23,12 +30,32 @@ class DistributedPlugin(BasePlugin):
|
||||
|
||||
|
||||
@DistributedPlugin("fsdp2").register()
def shard_model_fsdp2(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
    """Wrap the model with FSDP2 sharding; extra kwargs are accepted for interface parity."""
    from .fsdp2 import FSDP2Engine

    return FSDP2Engine(dist_config).shard_model(model)
|
||||
|
||||
|
||||
@DistributedPlugin("fsdp2").register("save_model")
def save_model_fsdp2(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Delegate full-state-dict checkpoint saving to the FSDP2 backend."""
    from .fsdp2 import save_model as fsdp2_save_model

    return fsdp2_save_model(model, output_dir, processor)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register()
def shard_model_deepspeed(model: HFModel, dist_config: PluginConfig, **kwargs) -> HFModel:
    """Build the accelerate-backed DeepSpeed engine for this model.

    Bug fix: `kwargs.get(...)` without a default returns None when the hub does
    not pass batching hints, which would override DeepSpeedEngine's `=1`
    defaults and feed `gradient_accumulation_steps=None` to Accelerator.
    Fall back to 1 explicitly to match the engine's own defaults.
    """
    from .deepspeed import DeepSpeedEngine

    return DeepSpeedEngine(
        dist_config,
        num_micro_batch=kwargs.get("num_micro_batch", 1),
        micro_batch_size=kwargs.get("micro_batch_size", 1),
    ).shard_model(model)
|
||||
|
||||
|
||||
@DistributedPlugin("deepspeed").register("save_model")
def save_model_deepspeed(model: HFModel, output_dir: str, processor: Processor) -> None:
    """Delegate ZeRO-aware checkpoint saving to the DeepSpeed backend."""
    from .deepspeed import save_model as deepspeed_save_model

    return deepspeed_save_model(model, output_dir, processor)
|
||||
|
||||
@@ -15,12 +15,22 @@
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import set_seed as hf_set_seed
|
||||
|
||||
from ..accelerator.interface import DistributedInterface
|
||||
from .constants import IGNORE_INDEX
|
||||
from .types import BatchInput, ModelInput, Processor, Tensor
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
"""Set seed for reproducibility.
|
||||
|
||||
Args:
|
||||
seed: Random seed.
|
||||
"""
|
||||
hf_set_seed(seed)
|
||||
|
||||
|
||||
def is_tokenizer(processor: Processor) -> bool:
|
||||
"""Check if processor is tokenizer.
|
||||
|
||||
|
||||
@@ -21,6 +21,13 @@ from functools import lru_cache
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from packaging import version
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from . import logging
|
||||
from .env import is_env_enabled
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -41,3 +48,22 @@ def _get_package_version(name: str) -> "Version":
|
||||
@lru_cache
|
||||
def is_transformers_version_greater_than(content: str):
|
||||
return _get_package_version("transformers") >= version.parse(content)
|
||||
|
||||
|
||||
def check_version(requirement: str, mandatory: bool = False) -> None:
    r"""Optionally check the package version."""
    # Honor the opt-out switch unless the dependency is strictly required.
    if is_env_enabled("DISABLE_VERSION_CHECK") and not mandatory:
        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    # These quantization backends must be built against the local environment.
    needs_no_build_isolation = any(pkg in requirement for pkg in ("gptqmodel", "autoawq"))
    if needs_no_build_isolation:
        pip_command = f"pip install {requirement} --no-build-isolation"
    else:
        pip_command = f"pip install {requirement}"

    hint = (
        f"To fix: run `{pip_command}`."
        if mandatory
        else f"To fix: run `{pip_command}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
    )

    require_version(requirement, hint)
|
||||
|
||||
@@ -108,11 +108,26 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
with gr.Column():
|
||||
enable_thinking = gr.Checkbox(value=True)
|
||||
report_to = gr.Dropdown(
|
||||
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
|
||||
choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "trackio", "all"],
|
||||
value="none",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
|
||||
with gr.Accordion("Trackio Settings", open=False):
|
||||
project = gr.Textbox(
|
||||
value="huggingface",
|
||||
label="Project Name",
|
||||
info="Project name for experiment tracking (used by Trackio, W&B, etc.)",
|
||||
)
|
||||
|
||||
trackio_space_id = gr.Textbox(
|
||||
value="trackio", label="Trackio Space ID", info="Hugging Face Space ID for Trackio deployment"
|
||||
)
|
||||
|
||||
hub_private_repo = gr.Checkbox(
|
||||
value=False, label="Private Repository", info="Make the Hugging Face repository private"
|
||||
)
|
||||
|
||||
input_elems.update(
|
||||
{
|
||||
logging_steps,
|
||||
@@ -128,6 +143,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
use_llama_pro,
|
||||
enable_thinking,
|
||||
report_to,
|
||||
project,
|
||||
trackio_space_id,
|
||||
hub_private_repo,
|
||||
}
|
||||
)
|
||||
elem_dict.update(
|
||||
@@ -146,6 +164,9 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
|
||||
use_llama_pro=use_llama_pro,
|
||||
enable_thinking=enable_thinking,
|
||||
report_to=report_to,
|
||||
project=project,
|
||||
trackio_space_id=trackio_space_id,
|
||||
hub_private_repo=hub_private_repo,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -166,3 +166,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
def fix_valuehead_cpu_loading():
|
||||
"""Fix valuehead model loading."""
|
||||
patch_valuehead_model()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
    """Disable Mistral regex network check.

    Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
    """
    backend = None
    try:
        from transformers.tokenization_utils_fast import TokenizersBackend

        backend = TokenizersBackend
    except ImportError:
        pass  # very old transformers: nothing to patch

    if backend is None or not hasattr(backend, "_patch_mistral_regex"):
        # Nothing to bypass in this transformers version.
        yield
        return

    saved_impl = backend._patch_mistral_regex
    # No-op replacement: hand the tokenizer back untouched.
    backend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer

    yield

    # Session teardown: put the original implementation back.
    backend._patch_mistral_regex = saved_impl
|
||||
|
||||
@@ -22,6 +22,7 @@ from transformers import AutoConfig, AutoModelForImageTextToText
|
||||
from llamafactory.data import get_template_and_fix_tokenizer
|
||||
from llamafactory.data.collator import MultiModalDataCollatorForSeq2Seq, prepare_4d_attention_mask
|
||||
from llamafactory.extras.constants import IGNORE_INDEX
|
||||
from llamafactory.extras.packages import is_transformers_version_greater_than
|
||||
from llamafactory.hparams import get_infer_args
|
||||
from llamafactory.model import load_tokenizer
|
||||
|
||||
@@ -116,14 +117,16 @@ def test_multimodal_collator():
|
||||
"labels": [
|
||||
[0, 1, 2, 3, q, q, q, q, q, q, q, q],
|
||||
],
|
||||
"position_ids": [
|
||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]],
|
||||
],
|
||||
"rope_deltas": [[-8]],
|
||||
"position_ids": [[[0, 1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0]]] * 3,
|
||||
"rope_deltas": [[0]],
|
||||
**tokenizer_module["processor"].image_processor(fake_image),
|
||||
}
|
||||
if not is_transformers_version_greater_than("5.0.0"):
|
||||
# adapt position_ids and rope_deltas for transformers < 5.0.0
|
||||
# https://github.com/huggingface/transformers/pull/43972
|
||||
expected_input["position_ids"] = [[[0, 1, 2, 3, 1, 1, 1, 1, 1, 1, 1, 1]]] * 3
|
||||
expected_input["rope_deltas"] = [[-8]]
|
||||
|
||||
assert batch_input.keys() == expected_input.keys()
|
||||
for k in batch_input.keys():
|
||||
assert batch_input[k].eq(torch.tensor(expected_input[k])).all()
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# change if test fails or cache is outdated
|
||||
0.9.5.106
|
||||
0.9.5.107
|
||||
|
||||
@@ -172,3 +172,33 @@ def _manage_distributed_env(request: FixtureRequest, monkeypatch: MonkeyPatch) -
|
||||
monkeypatch.setattr(torch.cuda, "device_count", lambda: 1)
|
||||
elif CURRENT_DEVICE == "npu":
|
||||
monkeypatch.setattr(torch.npu, "device_count", lambda: 1)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def bypass_mistral_regex_check():
    """Disable Mistral regex network check.

    Monkey-patch TokenizersBackend._patch_mistral_regex into a no-op.
    """
    backend = None
    try:
        from transformers.tokenization_utils_fast import TokenizersBackend

        backend = TokenizersBackend
    except ImportError:
        pass  # very old transformers: nothing to patch

    if backend is None or not hasattr(backend, "_patch_mistral_regex"):
        # Nothing to bypass in this transformers version.
        yield
        return

    saved_impl = backend._patch_mistral_regex
    # No-op replacement: hand the tokenizer back untouched.
    backend._patch_mistral_regex = lambda cls, tokenizer, *args, **kwargs: tokenizer

    yield

    # Session teardown: put the original implementation back.
    backend._patch_mistral_regex = saved_impl
|
||||
|
||||
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
51
tests_v1/plugins/model_plugins/test_quantization_plugin.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from llamafactory.v1.config.model_args import ModelArguments
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
|
||||
|
||||
bitsandbytes = pytest.importorskip("bitsandbytes")
|
||||
|
||||
|
||||
def check_quantization_status(model):
    """Map each quantization backend to the module names it quantized in the model."""
    quantized_info = {"bnb": []}

    # bitsandbytes swaps nn.Linear for Linear8bitLt (8-bit) / Linear4bit (4-bit).
    bnb_linear_types = (bitsandbytes.nn.modules.Linear8bitLt, bitsandbytes.nn.modules.Linear4bit)
    for name, module in model.named_modules():
        if isinstance(module, bnb_linear_types):
            quantized_info["bnb"].append(name)

    return quantized_info
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cuda"])
@pytest.mark.parametrize("name, quantization_bit", [("bnb", 4), ("auto", 4)])
def test_quantization_plugin(name, quantization_bit):
    """End-to-end check that ModelEngine applies bitsandbytes quantization."""
    quant_config = {"name": name, "quantization_bit": quantization_bit}
    model_args = ModelArguments(model="llamafactory/tiny-random-qwen3", quant_config=quant_config)

    engine = ModelEngine(model_args=model_args)
    quantized_info = check_quantization_status(engine.model)
    print(f"Quantized weights for method {name} with {quantization_bit} bit: {quantized_info}")
    assert any(v for v in quantized_info.values()), "model is not quantized properly."
|
||||
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
104
tests_v1/plugins/trainer_plugins/distributed/test_fsdp2.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests: FSDP2 meta-device loading vs normal loading consistency.
|
||||
|
||||
Validates that the FSDP2 meta loading path behaves correctly for tied weights
|
||||
and non-persistent buffers by comparing it with the standard non-meta path.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers import AutoConfig
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||
from llamafactory.v1.config.arg_parser import get_args
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
from llamafactory.v1.plugins.trainer_plugins.distributed.fsdp2 import FSDP2Engine
|
||||
|
||||
|
||||
TINY_MODEL = "llamafactory/tiny-random-qwen3"
|
||||
|
||||
|
||||
def collect_non_persistent_buffers(model):
    """Gather every non-persistent buffer in the model, keyed by fully-qualified name.

    Values are detached CPU clones so they remain valid after the model is freed.
    """
    gathered = {}
    for prefix, submodule in model.named_modules():
        non_persistent = getattr(submodule, "_non_persistent_buffers_set", set())
        for buffer_name in non_persistent:
            tensor = getattr(submodule, buffer_name, None)
            if tensor is None:
                continue
            key = f"{prefix}.{buffer_name}" if prefix else buffer_name
            gathered[key] = tensor.detach().cpu().clone()
    return gathered
|
||||
|
||||
|
||||
def test_fsdp2_meta_loading_buffers_and_tied_weights():
    """Verify non-persistent buffers and tied weights consistency after meta load."""
    # 1. Initialize DistributedInterface for single process
    DistributedInterface()

    # 2. Build FSDP2Engine config
    engine = FSDP2Engine(
        {
            "name": "fsdp2",
            "mixed_precision": "bf16",
            "reshard_after_forward": True,
            "offload_params": False,
            "pin_memory": False,
            "dcp_path": None,
        }
    )

    config = AutoConfig.from_pretrained(TINY_MODEL)

    # --- NORMAL PATH ---
    # Reference run: load real weights eagerly, then shard.
    normal_args, *_ = get_args(dict(model=TINY_MODEL, init_config=None))
    normal_engine = ModelEngine(model_args=normal_args)
    # Cast to bf16 to match the engine's mixed_precision setting.
    normal_model = normal_engine.model.to(torch.bfloat16)

    normal_model = engine.shard_model(normal_model)
    normal_non_persistent = collect_non_persistent_buffers(normal_model)

    # Free the reference model before the meta-path run.
    del normal_model

    # --- META PATH ---
    # Deferred init: model is created on the meta device, weights loaded later.
    meta_args, *_ = get_args(dict(model=TINY_MODEL, init_config={"name": "init_on_meta"}))
    meta_model_engine = ModelEngine(model_args=meta_args)
    meta_model = meta_model_engine.model

    assert meta_model.device.type == "meta", "Model should be on meta device"

    # Process meta device: save buffers -> tie_weights -> load from checkpoint -> restore buffers
    meta_model = engine.shard_model(meta_model)
    meta_non_persistent = collect_non_persistent_buffers(meta_model)

    # 3. Tied weights (embed_tokens.weight and lm_head.weight)
    # Only meaningful when the config ties input/output embeddings.
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)
    if tie_word_embeddings:
        # Identity check (`is`): tying must share the tensor, not copy it.
        assert meta_model.lm_head.weight is meta_model.model.embed_tokens.weight, (
            "Weights should be tied after loading"
        )

    del meta_model

    # 4. Non-persistent buffers (e.g., inv_freq)
    # Both paths must expose the same buffer set with matching values.
    normal_buf_keys = set(normal_non_persistent.keys())
    meta_buf_keys = set(meta_non_persistent.keys())
    assert normal_buf_keys == meta_buf_keys, "Non-persistent buffer keys mismatch"

    for key in sorted(normal_buf_keys & meta_buf_keys):
        nb = normal_non_persistent[key]
        mb = meta_non_persistent[key]
        assert nb.shape == mb.shape, f"Buffer shape mismatch: {key}"
        # Compare in float32 to avoid bf16 precision artifacts.
        assert torch.allclose(nb.float(), mb.float(), atol=1e-5), f"Buffer value mismatch: {key}"
|
||||
Reference in New Issue
Block a user