12 Commits

Author SHA1 Message Date
Kingsley
436d26bc28 fix: projector lookup for gemma4 modules (#10382)
Co-authored-by: yiluoAK_47 <yiluoAK_47@163.com>
2026-04-12 08:32:14 +08:00
Kingsley
c109c061e5 [model] set mm_projectors for omni models (#10378) 2026-04-10 18:12:57 +08:00
Kingsley
fa09c01c36 fix: gemma4 mm_token_type_ids padding (#10359) 2026-04-06 13:14:45 +08:00
Kingsley
eae6f0b541 [model] gemma4 (#10346) 2026-04-05 12:10:28 +08:00
Kingsley
acac63ef35 [data] fix qwen3vl timestamp (#10338) 2026-04-01 22:40:12 +08:00
浮梦
e5e8546493 [misc] fix moe (#10334)
Co-authored-by: frozenleaves <frozen@Mac.local>
2026-03-31 23:04:45 +08:00
Cui-yshoho
97433c53b6 [feat] support LlamaFactory SFT training by HyperParallel FSDP2 backend (#10289) 2026-03-30 10:47:20 +08:00
sunyi0505
b5afabe3d2 [v1] support ulysses cp for fsdp2 (#10262) 2026-03-27 16:22:48 +08:00
jiaqiw09
df2e6edb7e [v1] add init on rank0 for fsdp2 (#10264) 2026-03-27 14:54:03 +08:00
Goalina
d02fcd3588 [ci] add nginx cache config for Ascend NPU CI environment (#10323) 2026-03-27 10:04:16 +08:00
jiaqiw09
c340aa2a33 [v1] add callbacks (#10255) 2026-03-26 19:59:57 +08:00
Hertz
1e536733c6 [data] fix mimo-v2 tool call (#10315) 2026-03-26 17:37:22 +08:00
38 changed files with 1820 additions and 63 deletions

.ai/CLAUDE.md (new file, 105 lines)

@@ -0,0 +1,105 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Commands
```bash
# Code style (auto-fix)
make style
# Code quality check (no modifications)
make quality
# Run all tests
make test
# Run a single test file
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/path/to/test_file.py
# Run tests matching a pattern
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ -k "test_name"
# License header check
make license
# Build package
make build
```
The project uses `uv` as the preferred package manager. Commands automatically use `uv run` / `uvx` if `uv` is available.
## Architecture
LlamaFactory has two parallel architectures controlled by the `USE_V1` environment variable:
- **v0 (default):** `api, webui > chat, eval, train > data, model > hparams > extras`
- **v1 (experimental, `USE_V1=1`):** `trainers > core > accelerator, plugins, config > utils`
Most active development happens in v0. The v1 architecture lives in `src/llamafactory/v1/`.
### Entry Points
The CLI entry point is `llamafactory-cli` / `lmf`, implemented in `src/llamafactory/cli.py:main()`, which dispatches to `launcher.py` based on `USE_V1`.
Available subcommands: `train`, `chat`, `api`, `export`, `webchat`, `webui`, `env`, `version`, `help`.
### Training Flow (v0)
```
run_exp() [tuner.py]
→ read_args() → parse YAML/JSON config
→ get_train_args() → produces typed argument dataclasses
→ routes to: run_sft / run_dpo / run_ppo / run_rm / run_pt / run_kto
→ optional: export_model()
```
Training is invoked with a YAML config: `llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml`
### Configuration System
All training parameters are YAML/JSON config files. Argument parsing in `src/llamafactory/hparams/parser.py` produces four typed dataclasses:
- `ModelArguments` — model/tokenizer selection, quantization
- `DataArguments` — datasets, templates, preprocessing
- `FinetuningArguments` — LoRA rank/target, training method (sft/dpo/ppo/rm/pt/kto)
- `TrainingArguments` — extends HuggingFace's `TrainingArguments`
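As a rough sketch of how a flat YAML config could be routed into these dataclasses (the trimmed-down classes and the `split_config` helper below are illustrative, not the actual implementation in `hparams/parser.py`):

```python
from dataclasses import dataclass, fields

# Hypothetical, trimmed-down stand-ins for the real argument dataclasses.
@dataclass
class ModelArguments:
    model_name_or_path: str = ""

@dataclass
class DataArguments:
    dataset: str = ""
    template: str = ""

def split_config(config: dict, arg_classes: list[type]) -> list:
    """Assign each flat YAML key to the dataclass that declares a matching field."""
    parsed = []
    for cls in arg_classes:
        names = {f.name for f in fields(cls)}
        parsed.append(cls(**{k: v for k, v in config.items() if k in names}))
    return parsed

model_args, data_args = split_config(
    {"model_name_or_path": "Qwen/Qwen3-0.6B", "dataset": "alpaca_en_demo", "template": "qwen3_nothink"},
    [ModelArguments, DataArguments],
)
```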
### Key Modules
| Module | Purpose |
|--------|---------|
| `src/llamafactory/model/loader.py` | Loads model + tokenizer; applies quantization, LoRA, patches |
| `src/llamafactory/model/patcher.py` | Model-specific compatibility patches |
| `src/llamafactory/data/template.py` | Prompt templates; `TEMPLATES` dict maps model family → format |
| `src/llamafactory/data/mm_plugin.py` | Multi-modal (image/video/audio) data handling |
| `src/llamafactory/data/processor/` | Per-stage data processors (supervised, pairwise, pretrain, etc.) |
| `src/llamafactory/train/sft/` | SFT trainer; other stages follow same structure |
| `src/llamafactory/chat/` | Inference engines: `hf_engine`, `vllm_engine`, `sglang_engine`, `kt_engine` |
| `src/llamafactory/extras/constants.py` | Enums and constants used across the project |
### Adding Support for a New Model
1. Add a prompt template to `src/llamafactory/data/template.py` in the `TEMPLATES` dict
2. Add any necessary model patches in `src/llamafactory/model/patcher.py`
3. Add multi-modal support in `src/llamafactory/data/mm_plugin.py` if needed
### Distributed Training
Multi-GPU training automatically uses `torchrun`. Additional backends:
- **Ray:** Optional Ray cluster support
- **HyperParallel FSDP2:** `src/llamafactory/train/hyper_parallel/`
- **Megatron-core:** `src/llamafactory/train/mca/`
### Testing
- `tests/` — v0 tests; `tests_v1/` — v1 tests
- Most training tests require GPU hardware
- pytest markers: `@pytest.mark.slow`, `@pytest.mark.runs_on(['cuda'])`
- Always set `WANDB_DISABLED=true` when running tests
### Code Style
- Ruff for linting and formatting (line length 119, Google-style docstrings)
- Python 3.11+ syntax
- Double quotes for strings
- All new files must include Apache 2.0 license header (checked by `make license`)


@@ -49,6 +49,12 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v6
+      - name: Set nginx-cache for Ascend CI
+        run: |
+          sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
+          pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
+          pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
       - name: Install uv
         uses: astral-sh/setup-uv@v7
         with:


@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-4B
-trust_remote_code: true
 model_class: llm
 template: qwen3_nothink


@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-0.6B
-trust_remote_code: true
 model_class: llm
 template: qwen3_nothink


@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-0.6B
-trust_remote_code: true
 model_class: llm
 template: qwen3_nothink


@@ -0,0 +1,23 @@
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm
template: qwen3_nothink

# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
  cp_mode: ulysses
  cp_size: 2

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_ulysses_cp
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10


@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-4B
-trust_remote_code: true
 model_class: llm
 template: qwen3_nothink
@@ -28,7 +27,6 @@ train_dataset: data/v1_sft_demo.yaml
 ### training
 output_dir: ./outputs/test_lora
 micro_batch_size: 1
-global_batch_size: 4
 cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: true


@@ -0,0 +1,40 @@
model: Qwen/Qwen3-4B
model_class: llm
template: qwen3_nothink

# PEFT Configuration
peft_config:
  name: lora
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: all

# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto

# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null

init_config:
  name: init_on_rank0

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128


@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-0.6B
-trust_remote_code: true
 model_class: llm
 template: qwen3_nothink


@@ -380,6 +380,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         for i, feature in enumerate(features):
             feature["token_type_ids"] = token_type_ids[i]

+        if "mm_token_type_ids" in mm_inputs:  # need tensor-like for gemma4
+            mm_token_type_ids = mm_inputs.pop("mm_token_type_ids")
+            max_len = max(len(ids) for ids in mm_token_type_ids)
+            padded = []
+            for ids in mm_token_type_ids:
+                pad_len = max_len - len(ids)
+                if self.tokenizer.padding_side == "right":
+                    padded.append(ids + [0] * pad_len)
+                else:
+                    padded.append([0] * pad_len + ids)
+
+            mm_inputs["mm_token_type_ids"] = torch.tensor(padded, dtype=torch.long)
+
         features: dict[str, torch.Tensor] = super().__call__(features)
         bsz, seq_len = features["input_ids"].shape[:2]
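In isolation, the padding step above behaves like this pure-Python sketch (the `pad_mm_token_type_ids` name is illustrative; the tensor conversion is omitted):

```python
def pad_mm_token_type_ids(batch: list[list[int]], padding_side: str = "right") -> list[list[int]]:
    """Pad ragged per-sample mm_token_type_ids to a rectangular batch with zeros."""
    max_len = max(len(ids) for ids in batch)
    padded = []
    for ids in batch:
        pad = [0] * (max_len - len(ids))
        # Right padding appends zeros; left padding prepends them.
        padded.append(ids + pad if padding_side == "right" else pad + ids)
    return padded
```

The collator then wraps the result in `torch.tensor(..., dtype=torch.long)` so it batches like the other tensor features.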


@@ -607,6 +607,194 @@ class Gemma3nPlugin(Gemma3Plugin):
        return messages


@dataclass
class Gemma4Plugin(BasePlugin):
    r"""Plugin for the Gemma4 multimodal model."""

    @override
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
        r"""Regularize videos, also tracking per-video FPS and frame indices for timestamp generation."""
        results, fps_per_video, durations, frames_indices = [], [], [], []
        for video in videos:
            frames: list[ImageObject] = []
            if _check_video_is_nested_images(video):
                frames = video
                fps_per_video.append(kwargs.get("video_fps", 2.0))
                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                frames_indices.append(list(range(len(frames))))
            else:
                container = av.open(video, "r")
                video_stream = next(stream for stream in container.streams if stream.type == "video")
                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
                original_fps = float(video_stream.average_rate)
                # to correctly calculate the timestamps
                frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in sample_indices:
                        frames.append(frame.to_image())

                if video_stream.duration is None:
                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                else:
                    durations.append(float(video_stream.duration * video_stream.time_base))

            frames = self._regularize_images(frames, **kwargs)["images"]
            results.append(frames)

        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        image_processor = getattr(processor, "image_processor", None)
        video_processor = getattr(processor, "video_processor", None)
        feature_extractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}
        if len(images) != 0:
            regularized = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(regularized, return_tensors="pt"))

        if len(videos) != 0:
            video_data = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            video_metadata = [
                {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
                for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
            ]
            mm_inputs.update(
                video_processor(
                    videos=video_data["videos"],
                    video_metadata=video_metadata,
                    return_tensors="pt",
                    return_metadata=True,
                    do_sample_frames=False,
                )
            )

        if len(audios) != 0:  # only for gemma4n
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )["audios"]
            mm_inputs.update(
                feature_extractor(
                    audios,
                    padding="max_length",
                    return_tensors="pt",
                )
            )

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)
        boi_token: str = getattr(processor, "boi_token")
        eoi_token: str = getattr(processor, "eoi_token")
        boa_token: str = getattr(processor, "boa_token")
        eoa_token: str = getattr(processor, "eoa_token")
        image_token: str = getattr(processor, "image_token")
        video_token: str = getattr(processor, "video_token")
        audio_token: str = getattr(processor, "audio_token")
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            num_image_soft_tokens: list[int] = list(
                mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images))
            )
            num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos)))
            video_metadata = mm_inputs.get("video_metadata", [])
        else:
            num_image_soft_tokens = [1] * len(images)
            num_video_soft_tokens = [1] * len(videos)
            video_metadata = [None] * len(videos)

        audio_iter = iter(audios)
        image_iter = iter(num_image_soft_tokens)
        video_iter = iter(zip(num_video_soft_tokens, video_metadata))
        for message in messages:
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                n = next(image_iter)
                content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1)

            while VIDEO_PLACEHOLDER in content:
                num_soft_tokens_per_frame, metadata = next(video_iter)
                if self.expand_mm_tokens:
                    timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
                    frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
                    video_str = " ".join(frame_strs)
                else:
                    video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"

                content = content.replace(VIDEO_PLACEHOLDER, video_str, 1)

            while AUDIO_PLACEHOLDER in content:
                current_audio = next(audio_iter)
                if self.expand_mm_tokens:
                    num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
                    audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
                else:
                    audio_str = f"{boa_token}{audio_token}{eoa_token}"

                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        # Pop metadata keys that must not be passed to the model.
        for key in (
            "num_soft_tokens_per_image",
            "num_soft_tokens_per_video",
            "video_metadata",
            "_gemma4_fps_per_video",
            "_gemma4_frames_indices",
            "_gemma4_num_audio_soft_tokens",
        ):
            mm_inputs.pop(key, None)

        mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
        return mm_inputs


@dataclass
class InternVLPlugin(BasePlugin):
    @override
@@ -1489,10 +1677,11 @@ class Qwen2VLPlugin(BasePlugin):
    @override
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
-       results, fps_per_video, durations = [], [], []
+       results, fps_per_video, durations, frames_indices = [], [], [], []
        for video in videos:
            frames: list[ImageObject] = []
            if _check_video_is_nested_images(video):
+               # we assume the frames have already been sampled from the video
                for frame in video:
                    if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                        raise ValueError("Invalid image found in video frames.")
@@ -1500,10 +1689,14 @@ class Qwen2VLPlugin(BasePlugin):
                frames = video
                fps_per_video.append(kwargs.get("video_fps", 2.0))
                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+               frames_indices.append(list(range(len(frames))))
            else:
                container = av.open(video, "r")
                video_stream = next(stream for stream in container.streams if stream.type == "video")
                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+               original_fps = float(video_stream.average_rate)
+               # for qwen3vl video timestamp calculation
+               frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])  # hack usage when do_sample_frames=False
                container.seek(0)
                for frame_idx, frame in enumerate(container.decode(video_stream)):
                    if frame_idx in sample_indices:
@@ -1522,7 +1715,7 @@ class Qwen2VLPlugin(BasePlugin):
            frames = self._regularize_images(frames, **kwargs)["images"]
            results.append(frames)

-       return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
+       return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}

    @override
    def _get_mm_inputs(
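The `idx / original_fps * kwargs.get("video_fps", 2.0)` expression in the hunk above maps decoder frame indices onto the resampled timeline. A standalone sketch (the `remap_frame_indices` name is illustrative):

```python
def remap_frame_indices(sample_indices: list[int], original_fps: float, video_fps: float = 2.0) -> list[float]:
    """Map decoder frame indices to positions on the resampled timeline.

    idx / original_fps is the wall-clock time of a sampled frame in seconds;
    multiplying by video_fps expresses that time in resampled-frame units.
    """
    return [idx / original_fps * video_fps for idx in sample_indices]
```

For a 30 fps source sampled at decoder frames 0, 30, and 60, the remapped indices are 0.0, 2.0, and 4.0 on the 2 fps grid.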
@@ -1637,8 +1830,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            video_metadata = [
-               {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
-               for video, duration in zip(videos["videos"], videos["durations"])
+               {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
+               for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
            ]
            mm_inputs.update(
                video_processor(
@@ -1646,6 +1839,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                    video_metadata=video_metadata,
                    fps=getattr(processor, "video_fps", 2.0),
                    return_metadata=True,
+                   do_sample_frames=False,  # avoid changing frames_indices
                )
            )
            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
@@ -1677,7 +1871,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
            image_grid_thw = mm_inputs.get("image_grid_thw", [])
            video_grid_thw = mm_inputs.get("video_grid_thw", [])
            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
-           video_metadata = mm_inputs.get("video_metadata", {})
+           video_metadata = mm_inputs.get("video_metadata", [])
        else:
            image_grid_thw = [None] * len(images)
@@ -2200,8 +2394,9 @@ PLUGINS = {
    "base": BasePlugin,
    "ernie_vl": ErnieVLPlugin,
    "gemma3": Gemma3Plugin,
-   "glm4v": GLM4VPlugin,
    "gemma3n": Gemma3nPlugin,
+   "gemma4": Gemma4Plugin,
+   "glm4v": GLM4VPlugin,
    "intern_vl": InternVLPlugin,
    "kimi_vl": KimiVLPlugin,
    "llama4": Llama4Plugin,


@@ -997,6 +997,55 @@ register_template(
)


register_template(
    name="gemma4",
    format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]),  # default thought signal included
    format_observation=StringFormatter(
        slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
    ),  # seems inconsistent with the chat template
    format_tools=ToolFormatter(tool_format="gemma4"),
    format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<turn|>"],
    default_system="You are a helpful assistant.",  # important for thinking
    thought_words=("<|channel>thought\n", "<channel|>"),
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        "gemma4",
        image_token="<|image|>",
        video_token="<|video|>",
    ),
    template_class=ReasoningTemplate,
)


register_template(
    name="gemma4n",
    format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]),  # default thought signal included
    format_observation=StringFormatter(
        slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
    ),
    format_tools=ToolFormatter(tool_format="gemma4"),
    format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<turn|>"],
    default_system="You are a helpful assistant.",  # important for thinking
    thought_words=("<|channel>thought\n", "<channel|>"),
    replace_eos=True,
    mm_plugin=get_mm_plugin(
        "gemma4",
        image_token="<|image|>",
        video_token="<|video|>",
        audio_token="<|audio|>",
    ),
    template_class=ReasoningTemplate,
)


register_template(
    name="glm4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),


@@ -209,6 +209,164 @@ class DefaultToolUtils(ToolUtils):
        return results


class Gemma4ToolUtils(ToolUtils):
    r"""Gemma-4 tool using template."""

    @override
    @staticmethod
    def tool_formatter(tools: list[dict[str, Any]]) -> str:
        def _format_parameters(properties: dict[str, Any]) -> str:
            parts: list[str] = []
            for name, schema in properties.items():
                item_parts: list[str] = []
                if schema.get("description"):
                    item_parts.append(f'description:<|"|>{schema["description"]}<|"|>')
                if schema.get("type"):
                    item_parts.append(f'type:<|"|>{str(schema["type"]).upper()}<|"|>')
                parts.append(f"{name}:{{{','.join(item_parts)}}}")
            return ",".join(parts)

        declarations: list[str] = []
        for tool in tools:
            function_data = tool.get("function", tool) if tool.get("type") == "function" else tool
            declaration = (
                f"declaration:{function_data['name']}"
                + "{"
                + f'description:<|"|>{function_data.get("description", "")}<|"|>'
            )
            params = function_data.get("parameters")
            if params:
                param_parts: list[str] = []
                if params.get("properties"):
                    param_parts.append(f"properties:{{{_format_parameters(params['properties'])}}}")
                if params.get("required"):
                    required_text = ",".join(f'<|"|>{item}<|"|>' for item in params["required"])
                    param_parts.append(f"required:[{required_text}]")
                if params.get("type"):
                    param_parts.append(f'type:<|"|>{str(params["type"]).upper()}<|"|>')
                declaration += f",parameters:{{{','.join(param_parts)}}}"
            response_declaration = function_data.get("response")
            if response_declaration:
                response_parts: list[str] = []
                if response_declaration.get("description"):
                    response_parts.append(f'description:<|"|>{response_declaration["description"]}<|"|>')
                response_type = str(response_declaration.get("type", "")).upper()
                if response_type == "OBJECT":
                    response_parts.append(f'type:<|"|>{response_type}<|"|>')
                declaration += f",response:{{{','.join(response_parts)}}}"
            declarations.append(declaration + "}")

        return "\n".join(declarations)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
        regex = re.compile(r"<\|tool_call\>call:([^{\s]+)\{(.*?)\}<tool_call\|>", re.DOTALL)
        matches = re.findall(regex, content)
        if not matches:
            return content

        def _parse_arguments(arg_text: str) -> Any:
            text = arg_text.strip()
            if not text:
                return {}
            # `function_formatter` writes dict arguments as `k:v,...` inside `{...}`.
            # The extractor captures only the inner text, so re-wrap it to parse as a JSON object.
            object_like_text = "{" + text + "}"
            # Convert Gemma string markers (<|"|>value<|"|>) to valid JSON strings.
            normalized = re.sub(
                r"<\|\"\|\>(.*?)<\|\"\|\>",
                lambda m: json.dumps(m.group(1), ensure_ascii=False),
                object_like_text,
                flags=re.DOTALL,
            )
            # Quote unquoted object keys so the payload can be parsed by json.loads.
            normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized)
            try:
                return json.loads(normalized)
            except json.JSONDecodeError:
                pass
            try:
                return json.loads(text)
            except json.JSONDecodeError:
                return text

        results: list[FunctionCall] = []
        for name, arg_block in matches:
            parsed_arguments = _parse_arguments(arg_block)
            if isinstance(parsed_arguments, str):
                arguments = parsed_arguments
            else:
                arguments = json.dumps(parsed_arguments, ensure_ascii=False)
            results.append(FunctionCall(name.strip(), arguments))

        return results

    @override
    @staticmethod
    def function_formatter(functions: list["FunctionCall"]) -> str:
        def _format_argument(argument: Any, escape_keys: bool = True) -> str:
            if isinstance(argument, str):
                return f'<|"|>{argument}<|"|>'
            if isinstance(argument, bool):
                return "true" if argument else "false"
            if isinstance(argument, dict):
                items: list[str] = []
                for key in sorted(argument.keys()):
                    formatted_key = f'<|"|>{key}<|"|>' if escape_keys else str(key)
                    formatted_value = _format_argument(argument[key], escape_keys=escape_keys)
                    items.append(f"{formatted_key}:{formatted_value}")
                return "{" + ",".join(items) + "}"
            if isinstance(argument, (list, tuple)):
                return "[" + ",".join(_format_argument(item, escape_keys=escape_keys) for item in argument) + "]"
            if argument is None:
                return "null"
            return str(argument)

        function_texts: list[str] = []
        for function in functions:
            name = function.name
            raw_arguments = function.arguments
            try:
                parsed_arguments = json.loads(raw_arguments)
            except (TypeError, json.JSONDecodeError):
                parsed_arguments = raw_arguments
            call_text = f"<|tool_call>call:{name}" + "{"
            if isinstance(parsed_arguments, dict):
                args_text = []
                for key in sorted(parsed_arguments.keys()):
                    value_text = _format_argument(parsed_arguments[key], escape_keys=False)
                    args_text.append(f"{key}:{value_text}")
                call_text += ",".join(args_text)
            elif isinstance(parsed_arguments, str):
                call_text += parsed_arguments
            else:
                call_text += _format_argument(parsed_arguments, escape_keys=False)
            call_text += "}<tool_call|>"
            function_texts.append(call_text)

        return "".join(function_texts)


class GLM4ToolUtils(ToolUtils):
    r"""GLM-4 tool using template."""

@@ -361,6 +519,8 @@ class MiniMaxM2ToolUtils(ToolUtils):
            prompt += "\n</invoke>"
            function_texts.append(prompt)

+       return "\n".join(function_texts)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:

@@ -721,6 +881,7 @@ class LFM2ToolUtils(ToolUtils):
TOOLS = {
    "default": DefaultToolUtils(),
+   "gemma4": Gemma4ToolUtils(),
    "glm4": GLM4ToolUtils(),
    "llama3": Llama3ToolUtils(),
    "lfm2": LFM2ToolUtils(),


@@ -865,6 +865,34 @@ register_model_group(
)


register_model_group(
    models={
        "Gemma-4-26B-A4B-Thinking": {
            DownloadSource.DEFAULT: "google/gemma-4-26B-A4B-it",
        },
        "Gemma-4-31B-Thinking": {
            DownloadSource.DEFAULT: "google/gemma-4-31B-it",
        },
    },
    template="gemma4",
    multimodal=True,
)


register_model_group(
    models={
        "Gemma-4-E2B-Thinking": {
            DownloadSource.DEFAULT: "google/gemma-4-E2B-it",
        },
        "Gemma-4-E4B-Thinking": {
            DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
        },
    },
    template="gemma4n",
    multimodal=True,
)


register_model_group(
    models={
        "GLM-4-9B": {


@@ -70,6 +70,10 @@ def is_matplotlib_available():
    return _is_package_available("matplotlib")


def is_hyper_parallel_available():
    return _is_package_available("hyper_parallel")


def is_mcore_adapter_available():
    return _is_package_available("mcore_adapter")


@@ -482,6 +482,24 @@ class FinetuningArguments(
            )
        },
    )
    use_hyper_parallel: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to use the HyperParallel distributed training backend (FSDP/TP). "
                "Only supported for the 'sft' stage with full fine-tuning."
            )
        },
    )
    hyper_parallel_args: str | None = field(
        default=None,
        metadata={
            "help": (
                "Path to a JSON file containing HyperParallel strategy arguments "
                "(e.g., tp_size, param_dtype). Used when use_hyper_parallel=True."
            )
        },
    )
    use_muon: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the Muon optimizer."},


@@ -125,7 +125,7 @@ def _setup_freeze_tuning(
    model_type = getattr(model.config, "model_type", None)
    if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
-       trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+       trainable_layers.extend(COMPOSITE_MODELS[model_type].projector_keys)

    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
    for name, param in model.named_parameters():


@@ -45,7 +45,7 @@ def apply_liger_kernel(
        from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
    elif model_type == "gemma3_text":
        from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-   elif model_type == "glm4":
+   elif model_type in ["glm", "glm4"]:  # for glm4-9b and glm4-32b respectively
        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
    elif model_type == "glm4v":
        from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel


@@ -35,7 +35,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
        forbidden_modules.add("output")

    if model_type in COMPOSITE_MODELS:
-       forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
+       forbidden_modules.update(COMPOSITE_MODELS[model_type].projector_keys)

    if freeze_vision_tower and model_type in COMPOSITE_MODELS:
        forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)

View File

@@ -147,6 +147,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
         _set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])

+    if model_type == "qwen3_5_moe":
+        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3_5MoeSparseMoeBlock])
+

 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.moe_aux_loss_coef:

View File

@@ -39,16 +39,26 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
 @dataclass
 class CompositeModel:
     model_type: str
-    projector_key: str
+    projector_keys: list[str]
     vision_model_keys: list[str]
     language_model_keys: list[str]
     lora_conflict_keys: list[str]

-    def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
-        for key in self.projector_key.split("."):
-            module = getattr(module, key)
-        return module
+    def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
+        mm_projectors: list[torch.nn.Module] = []
+        for projector_key in self.projector_keys:
+            project_module = module
+            for key in projector_key.split("."):
+                project_module = getattr(project_module, key, None)
+                if project_module is None:  # e.g. the larger gemma4 variant has no embed_audio
+                    logger.warning_rank0(f"Projector key {projector_key} not found in module {module.__class__.__name__}.")
+                    break
+
+            if project_module is not None:
+                mm_projectors.append(project_module)
+
+        return mm_projectors


 COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
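The dotted-key traversal that `get_projectors` performs can be sketched standalone. The following is an illustrative rewrite, not repository code (`resolve_modules` and the sample attribute names are hypothetical), assuming only that missing keys should be skipped with a warning rather than raise:

```python
from types import SimpleNamespace

def resolve_modules(root, dotted_keys):
    """Walk each "a.b.c" path with getattr; skip paths whose keys are missing."""
    found = []
    for dotted in dotted_keys:
        node = root
        for part in dotted.split("."):
            node = getattr(node, part, None)
            if node is None:  # e.g. a model variant that lacks an audio projector
                break
        if node is not None:
            found.append(node)
    return found

# A stand-in for a composite model with one projector present and one absent.
merger = SimpleNamespace(name="merger")
model = SimpleNamespace(visual=SimpleNamespace(merger=merger))
print(len(resolve_modules(model, ["visual.merger", "model.embed_audio"])))  # 1
```

Using `getattr(..., None)` with an early `break` is what lets one registration cover sibling checkpoints that ship different projector subsets.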
@@ -56,7 +66,7 @@ COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
 def _register_composite_model(
     model_type: str,
-    projector_key: Optional[str] = None,
+    projector_keys: list[str] | None = None,
     vision_model_keys: Optional[list[str]] = None,
     language_model_keys: Optional[list[str]] = None,
     lora_conflict_keys: Optional[list[str]] = None,
@@ -65,7 +75,7 @@ def _register_composite_model(
     Args:
         model_type: model type
-        projector_key: multi_modal_projector
+        projector_keys: multi_modal_projector
         vision_model_keys: vision_tower
         language_model_keys: language_model
         lora_conflict_keys: None
@@ -73,7 +83,7 @@ def _register_composite_model(
     """
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
-        projector_key=projector_key or "multi_modal_projector",
+        projector_keys=projector_keys or ["multi_modal_projector"],
         vision_model_keys=vision_model_keys or ["vision_tower"],
         language_model_keys=language_model_keys or ["language_model", "lm_head"],
         lora_conflict_keys=lora_conflict_keys or [],
@@ -136,12 +146,16 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
         if model_type in COMPOSITE_MODELS:
-            mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
+            mm_projectors = COMPOSITE_MODELS[model_type].get_projectors(model)
         else:
             return

-        logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
-        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
+        logger.info_rank0(
+            f"Casting multimodal projector outputs in {model_args.compute_dtype}: "
+            f"{COMPOSITE_MODELS[model_type].projector_keys}."
+        )
+        for mm_projector in mm_projectors:
+            mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


 def configure_visual_model(config: "PretrainedConfig") -> None:
@@ -166,9 +180,9 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
             forbidden_modules.update(vision_model_keys)

         if finetuning_args.freeze_multi_modal_projector:
-            projector_key = COMPOSITE_MODELS[model_type].projector_key
-            logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
-            forbidden_modules.add(projector_key)
+            projector_keys = COMPOSITE_MODELS[model_type].projector_keys
+            logger.info_rank0(f"Set multi model projector not trainable: {projector_keys}.")
+            forbidden_modules.update(projector_keys)

         if finetuning_args.freeze_language_model:
             language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
@@ -200,7 +214,7 @@ def patch_target_modules(
 _register_composite_model(
     model_type="dots_ocr",
-    projector_key="vision_tower.merger",
+    projector_keys=["vision_tower.merger"],
     vision_model_keys=["vision_tower"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["merger"],
@@ -219,10 +233,18 @@ _register_composite_model(
 )

+_register_composite_model(
+    model_type="gemma4",
+    projector_keys=["model.embed_vision", "model.embed_audio"],
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["per_layer_projection_norm"],
+)
+

 # copied from qwen2vl
 _register_composite_model(
     model_type="glm4v",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -231,7 +253,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="glm4v_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -240,7 +262,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="glm_ocr",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -257,7 +279,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="Keye",
-    projector_key="mlp_AR",
+    projector_keys=["mlp_AR"],
     vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embedding"],
@@ -292,7 +314,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="minicpmv",
-    projector_key="resampler",
+    projector_keys=["resampler"],
     vision_model_keys=["vpm"],
     language_model_keys=["llm"],
 )
@@ -300,7 +322,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="minicpmo",
-    projector_key="resampler",
+    projector_keys=["resampler"],
     vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
     language_model_keys=["llm"],
     lora_conflict_keys=["audio_projection_layer"],
@@ -309,7 +331,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="mistral3",
-    projector_key="model.multi_modal_projector",
+    projector_keys=["model.multi_modal_projector"],
 )
@@ -332,7 +354,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen2_5_omni_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
     vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
     language_model_keys=["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -341,7 +363,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen2_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -350,7 +372,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen2_5_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -359,7 +381,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -368,7 +390,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_vl_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -377,7 +399,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_omni_moe_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
     vision_model_keys=[
         "visual.pos_embed",
         "visual.patch_embed",
@@ -392,7 +414,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_5",
-    projector_key="model.visual.merger",
+    projector_keys=["model.visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
@@ -401,7 +423,7 @@ _register_composite_model(
 _register_composite_model(
     model_type="qwen3_5_moe",
-    projector_key="model.visual.merger",
+    projector_keys=["model.visual.merger"],
     vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
     language_model_keys=["language_model", "lm_head"],
     lora_conflict_keys=["patch_embed"],

View File

@@ -0,0 +1,18 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .workflow import run_sft

__all__ = ["run_sft"]

View File

@@ -0,0 +1,183 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import SaveProcessorCallback
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    if not is_hyper_parallel_available():
        raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")

    from hyper_parallel.integration.llamafactory import (  # pylint: disable=C0415
        HyperParallelArguments,
        HyperParallelTrainer,
    )

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    ref_model = None
    if finetuning_args.use_asft_loss:
        ref_model = create_ref_model(model_args, finetuning_args)

    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        model=model if not training_args.predict_with_generate else None,
        pad_to_multiple_of=8 if training_args.do_train else None,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    if is_transformers_version_greater_than("4.58.0"):
        extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
        if not isinstance(extra_ids, list):
            extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
            string_tokens = [str(t) for t in extra_special_tokens]
            extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)

        all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
        gen_kwargs["eos_token_id"] = list(dict.fromkeys(all_eos_ids))
    else:
        gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids

    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id

    hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)
    callbacks = list(callbacks or [])
    processor = tokenizer_module.get("processor")
    if processor is not None:
        callbacks.append(SaveProcessorCallback(processor))

    compute_loss_func = None
    if finetuning_args.use_dft_loss:
        compute_loss_func = dft_loss_func
    elif finetuning_args.use_eaft_loss:
        compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(  # noqa: E731
            outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
        )
    elif finetuning_args.use_asft_loss:
        from functools import partial

        compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)

    trainer = HyperParallelTrainer(
        hp_args=hp_args,
        model=model,
        args=training_args,
        finetuning_args=finetuning_args,
        data_collator=data_collator,
        callbacks=callbacks,
        gen_kwargs=gen_kwargs,
        ref_model=ref_model,
        compute_loss_func=compute_loss_func,
        **dataset_module,
        **tokenizer_module,
        **metric_module,
    )

    if finetuning_args.use_badam:
        from types import MethodType

        from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore[import]

        trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
        trainer.add_callback(BAdamCallback)

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="sft"
            )

        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
            keys = ["loss"]
            if isinstance(dataset_module.get("eval_dataset"), dict):
                keys += sum(
                    [[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()],
                    [],
                )
            else:
                keys += ["eval_loss", "eval_accuracy"]

            plot_loss(training_args.output_dir, keys=keys)

    if training_args.predict_with_generate:
        tokenizer.padding_side = "left"

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)

    # Create model card
    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
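The `eos_token_id` assembly above filters out `-1` (tokens the tokenizer could not convert) and deduplicates with `dict.fromkeys`, which preserves first-seen order. A quick standalone illustration with made-up token ids:

```python
# Made-up token ids for illustration only.
eos_token_id = 2
extra_ids = [7, 2, -1, 9, 7]  # -1 marks tokens the tokenizer could not convert

all_eos_ids = [eos_token_id] + [i for i in extra_ids if i != -1]
deduped = list(dict.fromkeys(all_eos_ids))  # order-preserving dedup
print(deduped)  # [2, 7, 9]
```

Unlike `set()`, `dict.fromkeys` guarantees the primary EOS token stays first in the list.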

View File

@@ -24,7 +24,12 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
+from ..extras.packages import (
+    is_hyper_parallel_available,
+    is_mcore_adapter_available,
+    is_ray_available,
+    is_transformers_version_greater_than,
+)
 from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -71,7 +76,16 @@ def _training_function(config: dict[str, Any]) -> None:
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last

-    if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
+    if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
+        if not is_hyper_parallel_available():
+            raise ImportError(
+                "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
+            )
+
+        from .hyper_parallel import run_sft as run_sft_hp
+
+        run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+    elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
         if not is_mcore_adapter_available():
             raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")

         if finetuning_args.stage == "pt":

View File

@@ -85,6 +85,10 @@ class TrainingArguments:
         default=42,
         metadata={"help": "Random seed that will be set at the beginning of training."},
     )
+    logging_steps: int = field(
+        default=1,
+        metadata={"help": "Log metrics every N optimizer steps."},
+    )

     def __post_init__(self) -> None:
         self.dist_config = get_plugin_config(self.dist_config)

View File

@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
 from ..accelerator.interface import Dim, DistributedInterface
 from ..config import TrainingArguments
 from ..utils import logging
+from ..utils.callbacks import (
+    CallbackHandler,
+    LoggingCallback,
+    TrainerCallback,
+    TrainerState,
+)
 from ..utils.helper import compute_valid_tokens
 from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
 from .utils.batching import BatchGenerator
@@ -52,6 +58,7 @@ class BaseTrainer:
         model: HFModel,
         renderer: Renderer,
         train_dataset: TorchDataset,
+        callbacks: list[TrainerCallback] | None = None,
     ) -> None:
         self.args = args
         self.model = model
@@ -64,6 +71,7 @@ class BaseTrainer:
         # cached variables
         self.device = DistributedInterface().current_device
         self.dp_size = DistributedInterface().get_world_size(Dim.DP)
+        self.cp_size = DistributedInterface().get_world_size(Dim.CP)
         self.model_input_names = self.renderer.processor.model_input_names

         self._create_batch_generator()
@@ -99,6 +107,29 @@
         self._init_optimizer()
         self._init_lr_scheduler()

+        # Callbacks: TrainerState tracks progress across the full run.
+        self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
+        for cb in callbacks or []:
+            self.callback_handler.add_callback(cb)
+
+        self.state = TrainerState(num_training_steps=self.num_training_steps)
+
+        if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
+            # qwen3.5 is not yet supported due to its different attention implementation.
+            if model.config.model_type == "qwen3_5":
+                raise RuntimeError(
+                    "Sequence parallelism is not supported for the qwen3.5 model due to its "
+                    "different attention implementation; support is planned for a future release."
+                )
+
+            from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
+
+            if model.config._attn_implementation != "flash_attention_2":
+                logger.warning_rank0(
+                    "Sequence parallelism is optimized for flash attention only. "
+                    "Replacing the attention implementation with flash_attention_2."
+                )
+                model.config._attn_implementation = "flash_attention_2"
+
+            SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)

     def _create_batch_generator(self) -> None:
         self.train_batch_generator = BatchGenerator(
             dataset=self.train_dataset,
@@ -157,7 +188,7 @@ class BaseTrainer:
         """
         batch_size, _ = batch["labels"].shape
         model_inputs = {
-            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
+            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)
         }
         labels = batch["labels"].to(self.device, non_blocking=True)
         outputs: ModelOutput = model(**model_inputs)
@@ -174,16 +205,31 @@
     def fit(self) -> None:
         """Train the model."""
         self.model.train()
+        self.callback_handler.on_train_begin(self.args, self.state)
         for epoch in range(self.args.num_train_epochs):
+            self.state.epoch = epoch
             self.train_batch_generator.set_epoch(epoch)
+            self.callback_handler.on_epoch_begin(self.args, self.state)
             for micro_batches in self.train_batch_generator:
                 self.global_step += 1
+                self.state.global_step = self.global_step
+                self.callback_handler.on_step_begin(self.args, self.state)
                 step_loss = 0
                 step_valid_tokens = compute_valid_tokens(micro_batches)
                 step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
                 num_micro = len(micro_batches)
                 for i, micro_batch in enumerate(micro_batches):
-                    loss = self.compute_loss(micro_batch)
+                    if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
+                        from ..plugins.model_plugins.parallelization.sequence_parallel import (
+                            SequenceParallelLossPlugin,
+                        )
+
+                        loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch)
+                    else:
+                        loss = self.compute_loss(micro_batch)
+
                     mini_step_valid_tokens = compute_valid_tokens([micro_batch])
                     # fsdp uses mean reduction so we need to scale the loss by dp_size
                     loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
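The scaling on the last line above can be checked with toy numbers (all values below are hypothetical): because FSDP mean-reduces gradients across `dp_size` data-parallel ranks, weighting each micro-batch loss by its valid tokens times `dp_size` over the global token count makes the accumulated step loss a per-token average over the whole global batch.

```python
# Toy numbers: one rank's two micro-batches out of a global batch of 80 valid tokens.
dp_size = 2
micro_losses = [0.8, 0.4]   # per-micro-batch losses on this rank
micro_tokens = [30, 10]     # valid (non-padding) tokens per micro-batch
step_valid_tokens = 80      # total valid tokens across all ranks (after all-reduce)

scaled = [
    loss * tokens * dp_size / (step_valid_tokens + 1e-6)
    for loss, tokens in zip(micro_losses, micro_tokens)
]
print(round(sum(scaled), 4))  # 0.7
```

The `1e-6` term mirrors the code above: it guards against division by zero when a step happens to contain no valid tokens.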
@@ -200,7 +246,24 @@
                     # deepspeed: engine.step() already ran inside backward at the sync boundary
                     grad_norm = self._deepspeed_engine.get_grad_norm()
                 else:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
+                    if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
+                        from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
+
+                        parameters = self.model.parameters()
+                        if isinstance(parameters, torch.Tensor):
+                            parameters = [parameters]
+                        else:
+                            parameters = list(parameters)
+
+                        grads = [p.grad for p in parameters if p.grad is not None]
+                        grad_norm = _get_total_norm(grads)
+                        grad_norm = grad_norm.to(self.device)
+                        _clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
+                        if isinstance(grad_norm, torch.distributed._tensor.DTensor):
+                            grad_norm = grad_norm.full_tensor().item()
+                    else:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(), self.args.max_grad_norm
+                        ).item()

                 # isfinite(): argument 'input' (position 1) must be Tensor, not float
                 if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
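Both branches above implement the same contract as `torch.nn.utils.clip_grad_norm_`: compute the total gradient norm, then scale every gradient by `max_grad_norm / total_norm` whenever that norm exceeds the threshold. A dependency-free sketch of the arithmetic, with toy gradient values:

```python
import math

max_grad_norm = 1.0
grads = [3.0, 4.0]  # pretend flattened gradient entries

total_norm = math.sqrt(sum(g * g for g in grads))  # L2 norm = 5.0
clip_coef = min(1.0, max_grad_norm / total_norm)   # only shrink, never grow
clipped = [g * clip_coef for g in grads]

print(total_norm, [round(g, 6) for g in clipped])  # 5.0 [0.6, 0.8]
```

The DTensor branch exists because under sequence parallelism each rank holds sharded gradients, so the norm itself is a distributed tensor that must be materialized with `full_tensor()` before it can be reported as a float.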
@@ -213,14 +276,41 @@ class BaseTrainer:
                 step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
                 DistributedInterface().sync()
-                if DistributedInterface().get_rank() == 0:
-                    print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
+                # Update state with step metrics
+                current_lr = (
+                    self.lr_scheduler.get_last_lr()[0]
+                    if hasattr(self.lr_scheduler, "get_last_lr")
+                    else self.args.learning_rate
+                )
+                self.state.loss = step_loss
+                self.state.grad_norm = grad_norm
+                self.state.learning_rate = current_lr
+                self.callback_handler.on_step_end(self.args, self.state)
+
+                # Logging: trainer decides when to log
+                if self.global_step % self.args.logging_steps == 0:
+                    logs = {
+                        "epoch": epoch,
+                        "step": self.global_step,
+                        "loss": step_loss,
+                        "grad_norm": grad_norm,
+                        "learning_rate": current_lr,
+                    }
+                    self.callback_handler.on_log(self.args, self.state, logs)

                 # Check if max_steps is reached
                 if self.global_step >= self.num_training_steps:
                     logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
+                    self.callback_handler.on_epoch_end(self.args, self.state)
+                    self.callback_handler.on_train_end(self.args, self.state)
                     return
+
+            self.callback_handler.on_epoch_end(self.args, self.state)
+
+        self.callback_handler.on_train_end(self.args, self.state)

     def save_model(self) -> None:
         """Save the model."""
         if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
@@ -234,3 +324,5 @@ class BaseTrainer:
             model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
             self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
         logger.info_rank0(f"Model saved to {self.args.output_dir}")
+
+        self.callback_handler.on_save(self.args, self.state)
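The trainer now owns the logging cadence: a log entry fires whenever `global_step` is a multiple of `logging_steps`. A tiny sketch of that schedule (step values are illustrative):

```python
# Which of the first 10 optimizer steps emit a log entry with logging_steps = 3?
logging_steps = 3
logged_steps = [step for step in range(1, 11) if step % logging_steps == 0]
```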


@@ -140,6 +140,9 @@ class ModelEngine:
             **init_kwargs,
         )
+        init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default"
+        model._init_mode = init_mode
+
         if self.args.peft_config is None:
             if self.is_train:
                 logger.info_rank0("Fine-tuning mode: full tuning")
@@ -147,6 +150,9 @@ class ModelEngine:
             else:
                 logger.info_rank0("Inference the original model")
         else:
+            if self.args.peft_config.name == "lora" and init_mode == "init_on_meta":
+                raise ValueError("Currently lora stage does not support loading model by meta.")
+
             from ..plugins.model_plugins.peft import PeftPlugin

             model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)


@@ -146,6 +146,8 @@ class Renderer:
         for sample in samples:
             if "messages" in sample:
                 model_input = self.render_messages(sample["messages"], sample.get("tools"))
+                if "position_ids" not in model_input:
+                    model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
             elif "chosen_messages" in sample and "rejected_messages" in sample:
                 chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
                 rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
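The fallback above gives every token a 1-based position when the rendered input carries no explicit `position_ids`. A self-contained sketch with a made-up `input_ids` list (only its length matters):

```python
# Hypothetical token ids standing in for a rendered message.
model_input = {"input_ids": [101, 7592, 2088, 102]}
if "position_ids" not in model_input:
    model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
```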


@@ -0,0 +1,59 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
import torch.distributed as dist
from torch import Tensor
def all_to_all_tensor(
local_input: Tensor,
scatter_dim: int,
gather_dim: int,
group: Optional[dist.ProcessGroup] = None,
):
seq_world_size = dist.get_world_size(group)
input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
dist.all_to_all(output_list, input_list, group=group)
return torch.cat(output_list, dim=gather_dim).contiguous()
class SeqAllToAll4D(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
group: dist.ProcessGroup,
local_input: Tensor,
scatter_dim: int,
gather_dim: int,
) -> Tensor:
ctx.group = group
ctx.scatter_dim = scatter_dim
ctx.gather_dim = gather_dim
return all_to_all_tensor(local_input, scatter_dim, gather_dim, group)
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]:
return (
None,
all_to_all_tensor(grad_output[0], ctx.gather_dim, ctx.scatter_dim, ctx.group),
None,
None,
)
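`SeqAllToAll4D` re-partitions attention inputs: before the exchange each rank holds `seq_len/N` positions of every head; afterwards it holds every position of `heads/N` heads. A pure-Python stand-in (no torch, no process group; `all_to_all_rows_cols` is an illustrative name) that models a tensor as rows = sequence positions, cols = heads:

```python
def all_to_all_rows_cols(shards, n_ranks):
    # shards[j] is rank j's (seq/N, H) block. Rank r's output is built by
    # taking column-chunk r from every rank and concatenating along rows,
    # yielding a (seq, H/N) block -- the scatter/gather of SeqAllToAll4D.
    out = []
    for r in range(n_ranks):
        rows = []
        for j in range(n_ranks):
            width = len(shards[j][0]) // n_ranks
            rows.extend(row[r * width:(r + 1) * width] for row in shards[j])
        out.append(rows)
    return out

# 2 ranks, sequence length 4, 2 heads; entry (s, h) records position and head.
full = [[(s, h) for h in range(2)] for s in range(4)]
shards = [full[:2], full[2:]]            # sequence-sharded inputs, one per rank
head_sharded = all_to_all_rows_cols(shards, n_ranks=2)
# rank 0 now holds every sequence position, but only head 0
```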


@@ -0,0 +1,199 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from functools import partial
import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformers
from ....accelerator.interface import Dim, DistributedInterface
from ....utils import logging
from ....utils.plugin import BasePlugin
from ....utils.types import ModelOutput
from .ulysses import (
UlyssesAttention,
get_ulysses_sequence_parallel_group,
get_ulysses_sequence_parallel_rank,
get_ulysses_sequence_parallel_world_size,
set_ulysses_sequence_parallel_group,
)
logger = logging.get_logger(__name__)
class SequenceParallelModelPlugin(BasePlugin):
def __call__(self, model, model_args):
return super().__call__(model, model_args)
class SequenceParallelLossPlugin(BasePlugin):
def __call__(self, model, inputs, *args, **kwargs):
return super().__call__(model, inputs, *args, **kwargs)
def new_flash_attn_forward(
query_states,
key_states,
value_states,
attention_mask,
sequence_parallel_size=1,
dropout=0,
deterministic=False,
is_causal=True,
group=None,
mode="ulysses",
attn_fn=None,
target_dtype=None,
**kwargs,
):
if mode == "ulysses":
dist_attn = UlyssesAttention(sequence_process_group=group, attn_fn=attn_fn)
attn_output = dist_attn(
query_states,
key_states,
value_states,
attention_mask,
query_length=query_states.shape[1] * sequence_parallel_size,
deterministic=deterministic,
dropout_p=dropout,
causal=is_causal,
position_ids=kwargs.get("position_ids", None),
target_dtype=target_dtype,
)
else:
raise NotImplementedError("Other sequence parallel modes are to be implemented.")
return attn_output
@SequenceParallelModelPlugin("ulysses").register()
def apply_sequence_parallel(model, model_args):
# Replace _flash_attention_forward with new_flash_attn_forward
module = sys.modules[model.__module__]
cp_size = model_args.get("cp_size", 1)
set_ulysses_sequence_parallel_group(DistributedInterface().get_group(Dim.CP))
try:
num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
except AttributeError:
num_attention_heads, num_key_value_heads = (
model.config.text_config.num_attention_heads,
model.config.text_config.num_key_value_heads,
)
assert num_attention_heads % cp_size == 0, "num_attention_heads must be divisible by cp_size"
assert num_key_value_heads % cp_size == 0 or cp_size % num_key_value_heads == 0, (
"num_key_value_heads must be divisible by cp_size, or cp_size by num_key_value_heads"
)
origin_attn = transformers.modeling_flash_attention_utils._flash_attention_forward
new_flash_attention_forward = partial(
new_flash_attn_forward,
group=get_ulysses_sequence_parallel_group(),
mode="ulysses",
attn_fn=origin_attn,
sequence_parallel_size=cp_size,
)
for module_name, module in list(sys.modules.items()):
try:
if (
hasattr(module, "__file__")
and "transformers" in module.__file__
and getattr(module._flash_attention_forward, "__name__", "") == "_flash_attention_forward"
):
module._flash_attention_forward = new_flash_attention_forward
logger.info_rank0(
f"Replaced _flash_attention_forward in module {module_name} with new_flash_attn_forward for sequence parallel."
)
except (AttributeError, TypeError):
continue
def padding_and_split_data(data, device_mesh=None):
if device_mesh is not None:
cp_size = device_mesh["cp"].size()
cp_rank = device_mesh["cp"].get_local_rank()
cp_group = device_mesh["cp"].get_group()
for k, v in data.items():
if isinstance(v, torch.Tensor) and v.ndim > 1:
data_len = torch.tensor(v.shape[-1], device=v.device, dtype=torch.int64)
global_data_len = [torch.empty_like(data_len) for _ in range(cp_size)]
dist.all_gather(global_data_len, data_len, group=cp_group)
max_data_len = max(global_data_len)
pad_size = max_data_len - v.shape[-1] + (cp_size - max_data_len % cp_size) % cp_size
if k == "labels":
pad_value = -100
elif k == "loss_weights":
pad_value = 0.0
else:
pad_value = 0
pad_data = F.pad(v, (0, pad_size), value=pad_value)
data[k] = torch.chunk(pad_data, chunks=cp_size, dim=-1)[cp_rank].contiguous()
return data
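`padding_and_split_data` pads each tensor to the longest length in the cp group and then up to the next multiple of `cp_size`, so `torch.chunk` yields equal shards on every rank. The pad arithmetic in isolation (`padded_len` is an illustrative helper, not part of the codebase):

```python
def padded_len(local_len, group_max_len, cp_size):
    # pad up to the group max, then round up to a multiple of cp_size
    pad = group_max_len - local_len + (cp_size - group_max_len % cp_size) % cp_size
    return local_len + pad
```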
@SequenceParallelLossPlugin("sequence_parallel_loss").register()
def sequence_parallel_loss(model, model_inputs):
device_mesh = DistributedInterface().get_device_mesh(Dim.CP)
model_inputs = {
k: v.to(dist.get_rank(), non_blocking=True) for k, v in model_inputs.items() if isinstance(v, torch.Tensor)
}
model_inputs = padding_and_split_data(model_inputs, device_mesh)
batch_size, _ = model_inputs["labels"].shape
outputs: ModelOutput = model(**model_inputs)
logits = outputs.logits.float()
labels = model_inputs["labels"]
cp_group = get_ulysses_sequence_parallel_group()
cp_world_size = get_ulysses_sequence_parallel_world_size(cp_group)
cp_rank = get_ulysses_sequence_parallel_rank(cp_group)
# use all_gather to collect labels from all sequence parallel processes
global_labels = [torch.empty_like(labels) for _ in range(cp_world_size)]
dist.all_gather(global_labels, labels, group=cp_group)
labels = torch.cat(global_labels, dim=1).contiguous()
shift_labels = labels[..., 1:].view(-1).contiguous()
shift_labels = F.pad(shift_labels, (0, 1), value=-100)
shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=-1)[cp_rank].contiguous()
# use all_gather to collect loss_weights from all sequence parallel processes
loss_weights = model_inputs["loss_weights"]
global_loss_weights = [torch.empty_like(loss_weights) for _ in range(cp_world_size)]
dist.all_gather(global_loss_weights, loss_weights, group=cp_group)
shift_loss_weights = torch.cat(global_loss_weights, dim=1).contiguous()
shift_loss_weights = shift_loss_weights[..., 1:].contiguous()
shift_logits = logits.view(shift_labels.size(0), -1).contiguous()
# use all_gather to collect log_probs from all sequence parallel processes
log_probs = -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
global_log_probs = dist.nn.all_gather(log_probs, group=cp_group)
global_log_probs = torch.cat(global_log_probs, dim=1).contiguous()
log_probs = global_log_probs[..., :-1].contiguous()
loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)
return loss
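The label handling above can be seen on plain lists: gather the full label row, shift it left by one position, pad the tail with the ignore index -100, and chunk the result per cp rank (`shift_and_chunk` is an illustrative name):

```python
def shift_and_chunk(labels, cp_world_size):
    shifted = labels[1:] + [-100]         # next-token targets; the tail is ignored
    n = len(shifted) // cp_world_size
    return [shifted[i * n:(i + 1) * n] for i in range(cp_world_size)]

chunks = shift_and_chunk([1, 2, 3, 4, 5, 6, 7, 8], cp_world_size=2)
# each of the 2 ranks receives 4 shifted targets
```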


@@ -0,0 +1,163 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
from .seq_comm import SeqAllToAll4D
_ULYSSES_SEQUENCE_PARALLEL_GROUP = None
def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
"""Set ulysses sequence parallel process group."""
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
_ULYSSES_SEQUENCE_PARALLEL_GROUP = group
def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
"""Get ulysses sequence parallel process group."""
global _ULYSSES_SEQUENCE_PARALLEL_GROUP
return _ULYSSES_SEQUENCE_PARALLEL_GROUP
def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
"""Get ulysses sequence parallel world size."""
group = get_ulysses_sequence_parallel_group() if group is None else group
return dist.get_world_size(group) if group else 1
def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
"""Get ulysses sequence parallel rank."""
group = get_ulysses_sequence_parallel_group() if group is None else group
return dist.get_rank(group) if group else 0
class UlyssesAttention(torch.nn.Module):
"""Initialization.
Arguments:
local_attention (Module): local attention with q,k,v
sequence_process_group (ProcessGroup): sequence parallel process group
scatter_idx (int): scatter_idx for all2all comm
gather_idx (int): gather_idx for all2all comm
attn_type (AttnType): attention type enum
"""
def __init__(
self,
sequence_process_group: dist.ProcessGroup = None,
scatter_idx: int = 2,
gather_idx: int = 1,
attn_fn: Optional[callable] = None,
) -> None:
super().__init__()
self.spg = sequence_process_group
self.scatter_idx = scatter_idx
self.gather_idx = gather_idx
self.attn_fn = attn_fn
def forward(
self,
query: Tensor,
key: Tensor,
value: Tensor,
attention_mask: torch.Tensor,
query_length: int,
dropout_p=0.0,
softmax_scale=None,
position_ids: Optional[torch.Tensor] = None,
causal=True,
deterministic=False,
target_dtype=None,
*args: Any,
) -> Tensor:
"""Forward.
Arguments:
query (Tensor): query input to the layer
key (Tensor): key input to the layer
value (Tensor): value input to the layer
attention_mask (Tensor): attention mask for the layer
query_length (int): the length of the query sequence
dropout_p (float, optional): dropout probability. Defaults to 0.0.
softmax_scale (float, optional): scale factor for softmax. Defaults to None, in which case head_dim ** -0.5 is used.
position_ids (torch.Tensor, optional): position ids for the attention. Defaults to None.
causal (bool, optional): whether to apply causal mask. Defaults to True.
deterministic (bool, optional): whether to apply dropout in deterministic way. Defaults to False.
target_dtype (torch.dtype, optional): target dtype for attention output. Defaults to None.
args: other args
Returns:
* output (Tensor): context output
"""
# TODO Merge three alltoall calls into one
# TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together!
# in shape : e.g., [s/p:h:]
# (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)
# scatter 2, gather 1
q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)
if softmax_scale is None:
softmax_scale = q.shape[-1] ** -0.5
if attention_mask is None:
if position_ids is not None:
attention_mask = torch.ones_like(position_ids).to(torch.int64)
else:
attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
else:
attention_mask = attention_mask.to(torch.int64)
global_attention_mask = [
torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
]
dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
attention_mask = torch.cat(global_attention_mask, dim=1)
context_layer = self.attn_fn(
q,
k,
v,
attention_mask,
query_length=query_length,
is_causal=causal,
dropout=dropout_p,
position_ids=position_ids,
softmax_scale=softmax_scale,
deterministic=deterministic,
target_dtype=target_dtype,
)
if isinstance(context_layer, tuple):
context_layer = context_layer[0]
# (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
# scatter 1, gather 2
output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)
# out e.g., [s/p::h]
return output


@@ -150,9 +150,6 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is

 @PeftPlugin("lora").register()
 def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
-    if model.device.type == "meta":
-        raise ValueError("Currently lora stage does not support loading model by meta.")
-
     adapter_name_or_path = config.get("adapter_name_or_path")
     if adapter_name_or_path:


@@ -17,6 +17,7 @@ import gc
 import os

 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from peft.tuners.lora import LoraLayer
 from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
@@ -84,10 +85,7 @@ class FSDP2Engine:
         )

         if self.device_mesh is not None:
-            try:
-                self.fsdp_mesh = self.device_mesh["dp"]
-            except Exception:
-                self.fsdp_mesh = self.device_mesh
+            self.fsdp_mesh = self.device_mesh

             logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
         else:
@@ -244,23 +242,57 @@ class FSDP2Engine:
         logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")

     def shard_model(self, model: HFModel) -> HFModel:
-        if model.device.type == "meta":
+        init_mode = getattr(model, "_init_mode", "init_on_default")
+        if init_mode == "init_on_rank0":
+            if getattr(model.config, "tie_word_embeddings", False):
+                model.tie_weights()
+
+            if self.rank == 0:
+                logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.")
+                full_sd = {k: v.clone() for k, v in model.state_dict().items()}
+            else:
+                full_sd = {}
+
+            # Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
+            saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
+            model = self.prepare_model(model)
+            device = get_current_accelerator()
+            model.to_empty(device=device)
+
+            # Scatter params from Rank 0 into all DTensor shards
+            # Broadcast the full state dict from the global rank-0 process to all ranks in this group.
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
+            set_model_state_dict(model, full_sd, options=options)
+
+            # Broadcast and restore non-persistent buffers
+            buffers_to_sync = [saved_buffers]
+            dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
+            self._restore_non_persistent_buffers(model, buffers_to_sync[0])
+
+            if self.rank == 0:
+                logger.info("init_on_rank0 sync complete.")
+        elif init_mode == "init_on_meta":
             non_persistent_buffers = self._save_non_persistent_buffers(model)
-            if getattr(model.config, "tie_word_embeddings", None):
+            if getattr(model.config, "tie_word_embeddings", False):
                 model.tie_weights()

             model = self.prepare_model(model)
             model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
             # fix tied broken for no-fsdp-wrap case
-            if getattr(model.config, "tie_word_embeddings", None):
+            if getattr(model.config, "tie_word_embeddings", False):
                 model.tie_weights()

             self._restore_non_persistent_buffers(model, non_persistent_buffers)
         else:
             model = self.prepare_model(model)

         return model

     def _load_from_dcp(self, model: HFModel, dcp_path: str):
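In the `init_on_rank0` path only rank 0 holds real weights; `set_model_state_dict(..., broadcast_from_rank0=True)` then fans them out while every other rank passes an empty dict. A toy single-process stand-in of that contract (the collective itself is simulated; names are illustrative):

```python
def broadcast_from_rank0(per_rank_state):
    # every rank ends up with a copy of rank 0's state dict
    src = per_rank_state[0]
    return [dict(src) for _ in per_rank_state]

per_rank = [{"w": [1.0, 2.0]}, {}, {}]   # rank 0 materialized, ranks 1-2 empty
synced = broadcast_from_rank0(per_rank)
```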


@@ -0,0 +1,24 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .logging_callback import LoggingCallback
from .trainer_callback import CallbackHandler, TrainerCallback, TrainerState
__all__ = [
"CallbackHandler",
"LoggingCallback",
"TrainerCallback",
"TrainerState",
]


@@ -0,0 +1,64 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import os
from typing import TYPE_CHECKING, Any
from .. import logging
from .trainer_callback import TrainerCallback, TrainerState
if TYPE_CHECKING:
from ...config import TrainingArguments
logger = logging.get_logger(__name__)
class LoggingCallback(TrainerCallback):
"""Logs training metrics to stdout on rank-0 and appends to ``state.log_history``.
On each logging step the entry is also persisted as a JSON line in
``<output_dir>/trainer_log.jsonl`` so that training history survives crashes.
"""
def on_log(
self,
args: TrainingArguments,
state: TrainerState,
logs: dict[str, Any],
**kwargs: Any,
) -> None:
# Persist in history regardless of rank
state.log_history.append(dict(logs))
# Everything below is rank-0 only
from ...accelerator.interface import DistributedInterface # lazy import
if DistributedInterface().get_rank() != 0:
return
# Human-readable output to stdout
display_logs = {**logs, "total_steps": state.num_training_steps}
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
logger.info_rank0(parts)
# Append to JSONL log file in output_dir
os.makedirs(args.output_dir, exist_ok=True)
log_file = os.path.join(args.output_dir, "trainer_log.jsonl")
with open(log_file, "a", encoding="utf-8") as f:
f.write(json.dumps(display_logs, ensure_ascii=False) + "\n")
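The stdout line built by `LoggingCallback` formats floats with four decimals and leaves other values alone; the same dict also becomes one JSON line. A self-contained check of that formatting rule:

```python
import json

logs = {"epoch": 0, "step": 10, "loss": 0.123456}
parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in logs.items())
jsonl_line = json.dumps(logs, ensure_ascii=False)   # one line per log entry
```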


@@ -0,0 +1,147 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ...config import TrainingArguments
@dataclass
class TrainerState:
"""A read-only snapshot of training progress passed to every callback hook.
Attributes:
epoch: Current epoch (0-indexed).
global_step: Number of optimizer steps completed so far.
num_training_steps: Total number of optimizer steps planned.
loss: Scalar loss value of the most recent step.
grad_norm: Gradient-norm value of the most recent step.
learning_rate: Current learning rate seen by the optimizer.
log_history: List of per-step log dicts emitted by ``LoggingCallback``.
"""
epoch: int = 0
global_step: int = 0
num_training_steps: int = 0
loss: float = 0.0
grad_norm: float = 0.0
learning_rate: float = 0.0
log_history: list[dict[str, Any]] = field(default_factory=list)
class TrainerCallback:
"""Abstract base class for training callbacks.
Subclass and override whichever hooks you need. All hooks receive:
- ``args`` the :class:`~llamafactory.v1.config.TrainingArguments`.
- ``state`` a :class:`TrainerState` snapshot (read-only).
- ``**kwargs`` extra keyword arguments (model, optimizer, …).
Callbacks are *observers*: they should NOT mutate training flow.
Hook call order::
on_train_begin
for each epoch:
on_epoch_begin
for each step:
on_step_begin
(forward / backward / optimizer.step)
on_step_end
[on_log] ← if this step is a logging step
on_epoch_end
on_train_end
"""
def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called once before the first training step."""
def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called once after the last training step."""
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called at the beginning of each epoch."""
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called at the end of each epoch."""
def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called before the forward/backward pass of each optimizer step."""
def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called after the optimizer step."""
def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any], **kwargs: Any) -> None:
"""Called when the trainer emits a log entry."""
def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
"""Called after the model checkpoint has been written to disk."""
class CallbackHandler:
"""Owns a list of :class:`TrainerCallback` instances and fans out hook calls.
Usage::
handler = CallbackHandler([LoggingCallback(), MyWandbCallback()], trainer=trainer)
handler.on_train_begin(args, state)
"""
def __init__(self, callbacks: list[TrainerCallback] | None = None, trainer: Any = None) -> None:
self.callbacks: list[TrainerCallback] = list(callbacks or [])
self.trainer = trainer
def add_callback(self, callback: TrainerCallback) -> None:
"""Append a callback to the handler."""
self.callbacks.append(callback)
def _call(self, event: str, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
if self.trainer is not None:
kwargs.setdefault("model", getattr(self.trainer, "model", None))
kwargs.setdefault("optimizer", getattr(self.trainer, "optimizer", None))
kwargs.setdefault("lr_scheduler", getattr(self.trainer, "lr_scheduler", None))
kwargs.setdefault("train_dataloader", getattr(self.trainer, "train_batch_generator", None))
for cb in self.callbacks:
getattr(cb, event)(args, state, **kwargs)
def on_train_begin(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_train_begin", args, state)
def on_train_end(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_train_end", args, state)
def on_epoch_begin(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_epoch_begin", args, state)
def on_epoch_end(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_epoch_end", args, state)
def on_step_begin(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_step_begin", args, state)
def on_step_end(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_step_end", args, state)
def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any]) -> None:
self._call("on_log", args, state, logs=logs)
def on_save(self, args: TrainingArguments, state: TrainerState) -> None:
self._call("on_save", args, state)


@@ -57,7 +57,7 @@ TEXT_MESSAGES = [
 ]

 VIDEO_MESSAGES = [
-    {"role": "user", "content": "<video>What is in this viode?"},
+    {"role": "user", "content": "<video>What is in this video?"},
     {"role": "assistant", "content": "A cat."},
 ]
@@ -210,6 +210,34 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)


+@pytest.mark.runs_on(["cpu", "mps"])
+@pytest.mark.skipif(not is_transformers_version_greater_than("5.6.0"), reason="Requires transformers>=5.6.0")
+def test_gemma4_plugin():
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-4-31B-it")
+    processor = tokenizer_module["processor"]
+    gemma4_plugin = get_mm_plugin(name="gemma4", image_token="<|image|>", video_token="<|video|>")
+    check_inputs = {"plugin": gemma4_plugin, **tokenizer_module}
+    # validate
+    mm_inputs = gemma4_plugin._get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, processor)
+    num_image_soft_tokens = 256  # when we use default max_soft_tokens=280
+    image_token = getattr(processor, "image_token")
+    boi_token = getattr(processor, "boi_token")
+    eoi_token = getattr(processor, "eoi_token")
+    expected_mm_type_ids = [[int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS]
+    check_inputs["expected_mm_messages"] = [
+        {"role": "user", "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?"},
+        {"role": "assistant", "content": "A cat."},
+    ]
+    for key in ("num_soft_tokens_per_image",):
+        mm_inputs.pop(key, None)
+    mm_inputs["mm_token_type_ids"] = expected_mm_type_ids
+    check_inputs["expected_mm_inputs"] = mm_inputs
+    check_inputs["expected_no_mm_inputs"] = {"mm_token_type_ids": expected_mm_type_ids}
+    _check_plugin(**check_inputs)
+
+
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_internvl_plugin():
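The `expected_mm_type_ids` expression in the gemma4 test marks positions holding the image token with 1 and everything else with 0. The same comprehension on a made-up `image_token_id` and batch:

```python
image_token_id = 99                      # hypothetical id for "<|image|>"
batch_ids = [[1, 99, 99, 5], [7, 8]]
mm_token_type_ids = [[int(tid == image_token_id) for tid in ids] for ids in batch_ids]
```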


@@ -0,0 +1,62 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.multiprocessing as mp
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.model_args import ModelArguments
from llamafactory.v1.core.model_engine import ModelEngine
from llamafactory.v1.plugins.model_plugins.parallelization.sequence_parallel import (
SequenceParallelModelPlugin,
sequence_parallel_loss,
)
from llamafactory.v1.utils.env import find_available_port
from llamafactory.v1.utils.pytest import dist_env
def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int):
with dist_env(local_rank, world_size, master_port):
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
# Initialize distributed interface with config
dist_config = {"cp_mode": "ulysses", "cp_size": cp_size, "dp_size": dp_size}
DistributedInterface(dist_config)
# Now create model engine
model_engine = ModelEngine(model_args=model_args)
# Apply sequence parallel plugin
SequenceParallelModelPlugin(dist_config.get("cp_mode", "ulysses"))(model_engine.model, dist_config)
model_inputs = {
"input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
"labels": torch.tensor([[1, 2, 3, 4, 5]]),
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
"position_ids": torch.tensor([[1, 2, 3, 4, 5]]),
"loss_weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]),
}
loss = sequence_parallel_loss(model_engine.model, model_inputs)
assert loss is not None
@pytest.mark.runs_on(["cuda", "npu"])
@pytest.mark.require_distributed(2)
@pytest.mark.parametrize("cp_size, dp_size", [(2, 1)])
def test_sequence_parallel_loss(cp_size, dp_size):
master_port = find_available_port()
world_size = cp_size * dp_size
mp.spawn(_test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size), nprocs=world_size)