Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-04-17 02:16:02 +08:00.

Compare commits: 12 commits on `main`, head `97d479fa92`:

436d26bc28, c109c061e5, fa09c01c36, eae6f0b541, acac63ef35, e5e8546493, 97433c53b6, b5afabe3d2, df2e6edb7e, d02fcd3588, c340aa2a33, 1e536733c6

## .ai/CLAUDE.md (new file, 105 lines)
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Commands

```bash
# Code style (auto-fix)
make style

# Code quality check (no modifications)
make quality

# Run all tests
make test

# Run a single test file
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/path/to/test_file.py

# Run tests matching a pattern
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ -k "test_name"

# License header check
make license

# Build package
make build
```

The project uses `uv` as the preferred package manager. Commands automatically use `uv run` / `uvx` if `uv` is available.

## Architecture

LlamaFactory has two parallel architectures controlled by the `USE_V1` environment variable:

- **v0 (default):** `api, webui > chat, eval, train > data, model > hparams > extras`
- **v1 (experimental, `USE_V1=1`):** `trainers > core > accelerator, plugins, config > utils`

Most active development happens in v0. The v1 architecture lives in `src/llamafactory/v1/`.

### Entry Points

CLI entry point is `llamafactory-cli` / `lmf` → `src/llamafactory/cli.py:main()`, which dispatches to `launcher.py` based on `USE_V1`.

Available subcommands: `train`, `chat`, `api`, `export`, `webchat`, `webui`, `env`, `version`, `help`.
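
The `USE_V1` dispatch can be sketched roughly as follows. This is a hypothetical, simplified stand-in for the real `cli.py` logic, and `pick_launcher` is an illustrative name, not a function in the repo:

```python
# Hypothetical stand-in for the USE_V1 branch in cli.py:main();
# the real code dispatches to launcher.py with much more setup.
def pick_launcher(env: dict[str, str]) -> str:
    """Return which architecture a command would be routed to."""
    return "v1" if env.get("USE_V1") == "1" else "v0"

assert pick_launcher({}) == "v0"  # v0 is the default
assert pick_launcher({"USE_V1": "1"}) == "v1"
```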

### Training Flow (v0)

```
run_exp() [tuner.py]
  → read_args() → parse YAML/JSON config
  → get_train_args() → produces typed argument dataclasses
  → routes to: run_sft / run_dpo / run_ppo / run_rm / run_pt / run_kto
  → optional: export_model()
```

Training is invoked with a YAML config: `llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml`

### Configuration System

All training parameters are supplied via YAML/JSON config files. Argument parsing in `src/llamafactory/hparams/parser.py` produces four typed dataclasses:

- `ModelArguments` — model/tokenizer selection, quantization
- `DataArguments` — datasets, templates, preprocessing
- `FinetuningArguments` — LoRA rank/target, training method (sft/dpo/ppo/rm/pt/kto)
- `TrainingArguments` — extends HuggingFace's `TrainingArguments`
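
As a rough illustration, a flat config dict is split across the typed groups like this. The field names below are simplified placeholders, not the real dataclass definitions:

```python
from dataclasses import dataclass

# Hypothetical, heavily simplified sketch of the four-way split done by
# get_train_args(); the real dataclasses carry many more fields.
@dataclass
class ModelArguments:
    model_name_or_path: str

@dataclass
class DataArguments:
    dataset: str
    template: str

@dataclass
class FinetuningArguments:
    stage: str
    finetuning_type: str

config = {
    "model_name_or_path": "Qwen/Qwen3-0.6B",
    "dataset": "identity",
    "template": "qwen3",
    "stage": "sft",
    "finetuning_type": "lora",
}
model_args = ModelArguments(config["model_name_or_path"])
data_args = DataArguments(config["dataset"], config["template"])
finetuning_args = FinetuningArguments(config["stage"], config["finetuning_type"])
```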

### Key Modules

| Module | Purpose |
|--------|---------|
| `src/llamafactory/model/loader.py` | Loads model + tokenizer; applies quantization, LoRA, patches |
| `src/llamafactory/model/patcher.py` | Model-specific compatibility patches |
| `src/llamafactory/data/template.py` | Prompt templates; `TEMPLATES` dict maps model family → format |
| `src/llamafactory/data/mm_plugin.py` | Multi-modal (image/video/audio) data handling |
| `src/llamafactory/data/processor/` | Per-stage data processors (supervised, pairwise, pretrain, etc.) |
| `src/llamafactory/train/sft/` | SFT trainer; other stages follow the same structure |
| `src/llamafactory/chat/` | Inference engines: `hf_engine`, `vllm_engine`, `sglang_engine`, `kt_engine` |
| `src/llamafactory/extras/constants.py` | Enums and constants used across the project |

### Adding Support for a New Model

1. Add a prompt template to `src/llamafactory/data/template.py` in the `TEMPLATES` dict
2. Add any necessary model patches in `src/llamafactory/model/patcher.py`
3. Add multi-modal support in `src/llamafactory/data/mm_plugin.py` if needed
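
Step 1 can be pictured with a toy version of the registration pattern. This is hypothetical and heavily simplified; the real `register_template` in `template.py` takes formatter objects, not plain strings:

```python
# Toy stand-in for the TEMPLATES registry in data/template.py.
TEMPLATES: dict[str, dict[str, str]] = {}

def register_template(name: str, user_slot: str, assistant_slot: str) -> None:
    """Register a named prompt format in the global registry."""
    TEMPLATES[name] = {"user": user_slot, "assistant": assistant_slot}

def render(name: str, messages: list[dict[str, str]]) -> str:
    """Fill each message into its role's slot, in order."""
    template = TEMPLATES[name]
    return "".join(template[m["role"]].format(content=m["content"]) for m in messages)

register_template("mymodel", "<|user|>\n{content}\n<|assistant|>\n", "{content}<|end|>\n")
```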

### Distributed Training

Multi-GPU training automatically uses `torchrun`. Additional backends:

- **Ray:** Optional Ray cluster support
- **HyperParallel FSDP2:** `src/llamafactory/train/hyper_parallel/`
- **Megatron-core:** `src/llamafactory/train/mca/`

### Testing

- `tests/` — v0 tests; `tests_v1/` — v1 tests
- Most training tests require GPU hardware
- pytest markers: `@pytest.mark.slow`, `@pytest.mark.runs_on(['cuda'])`
- Always set `WANDB_DISABLED=true` when running tests

### Code Style

- Ruff for linting and formatting (line length 119, Google-style docstrings)
- Python 3.11+ syntax
- Double quotes for strings
- All new files must include the Apache 2.0 license header (checked by `make license`)
## .github/workflows/tests_npu.yml (vendored, +6)

```diff
@@ -49,6 +49,12 @@ jobs:
       - name: Checkout
         uses: actions/checkout@v6

+      - name: Set nginx-cache for Ascend CI
+        run: |
+          sed -Ei 's@(ports|archive).ubuntu.com@cache-service.nginx-pypi-cache.svc.cluster.local:8081@g' /etc/apt/sources.list
+          pip config set global.index-url http://cache-service.nginx-pypi-cache.svc.cluster.local/pypi/simple
+          pip config set global.trusted-host cache-service.nginx-pypi-cache.svc.cluster.local
+
       - name: Install uv
         uses: astral-sh/setup-uv@v7
         with:
```
```diff
@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-4B
-trust_remote_code: true
 model_class: llm
 
 template: qwen3_nothink
```
```diff
@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-0.6B
-
 model_class: llm
 
 template: qwen3_nothink
```
```diff
@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-0.6B
-trust_remote_code: true
 model_class: llm
 
 template: qwen3_nothink
```
## examples/v1/train_full/train_full_ulysses_cp.yaml (new file, 23 lines)

```yaml
model: Qwen/Qwen3-0.6B
trust_remote_code: true
model_class: llm

template: qwen3_nothink

# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null
  cp_mode: ulysses
  cp_size: 2

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: outputs/test_ulysses_cp
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: false
max_steps: 10
```
```diff
@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-4B
-trust_remote_code: true
 model_class: llm
 
 template: qwen3_nothink
@@ -28,7 +27,6 @@ train_dataset: data/v1_sft_demo.yaml
 ### training
 output_dir: ./outputs/test_lora
 micro_batch_size: 1
-global_batch_size: 4
 cutoff_len: 2048
 learning_rate: 1.0e-4
 bf16: true
```
## examples/v1/train_lora/train_lora_sft_rank0.yaml (new file, 40 lines)

```yaml
model: Qwen/Qwen3-4B
model_class: llm

template: qwen3_nothink

# PEFT Configuration
peft_config:
  name: lora
  r: 16
  lora_alpha: 32
  lora_dropout: 0.05
  target_modules: all

# Kernel Config
kernel_config:
  name: auto
  include_kernels: auto

# FSDP Config
dist_config:
  name: fsdp2
  dcp_path: null

init_config:
  name: init_on_rank0

### data
train_dataset: data/v1_sft_demo.yaml

### training
output_dir: ./outputs/test_lora
micro_batch_size: 1
cutoff_len: 2048
learning_rate: 1.0e-4
bf16: true
max_steps: 10

### sample
sample_backend: hf
max_new_tokens: 128
```
```diff
@@ -1,5 +1,4 @@
 model: Qwen/Qwen3-0.6B
-trust_remote_code: true
 model_class: llm
 
 template: qwen3_nothink
```
```diff
@@ -380,6 +380,19 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             for i, feature in enumerate(features):
                 feature["token_type_ids"] = token_type_ids[i]
 
+        if "mm_token_type_ids" in mm_inputs:  # need tensor-like for gemma4
+            mm_token_type_ids = mm_inputs.pop("mm_token_type_ids")
+            max_len = max(len(ids) for ids in mm_token_type_ids)
+            padded = []
+            for ids in mm_token_type_ids:
+                pad_len = max_len - len(ids)
+                if self.tokenizer.padding_side == "right":
+                    padded.append(ids + [0] * pad_len)
+                else:
+                    padded.append([0] * pad_len + ids)
+
+            mm_inputs["mm_token_type_ids"] = torch.tensor(padded, dtype=torch.long)
+
         features: dict[str, torch.Tensor] = super().__call__(features)
 
         bsz, seq_len = features["input_ids"].shape[:2]
```
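
The padding logic added above can be exercised standalone. This is a minimal sketch without torch, and `pad_batch` is an illustrative helper, not part of the codebase:

```python
def pad_batch(seqs: list[list[int]], padding_side: str = "right", pad_value: int = 0) -> list[list[int]]:
    """Pad variable-length id lists to a rectangle, mirroring the collator logic."""
    max_len = max(len(ids) for ids in seqs)
    padded = []
    for ids in seqs:
        pad = [pad_value] * (max_len - len(ids))
        # right padding appends zeros; left padding prepends them
        padded.append(ids + pad if padding_side == "right" else pad + ids)
    return padded
```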
```diff
@@ -607,6 +607,194 @@ class Gemma3nPlugin(Gemma3Plugin):
         return messages
 
 
+@dataclass
+class Gemma4Plugin(BasePlugin):
+    r"""Plugin for the Gemma4 multimodal model."""
+
+    @override
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+        r"""Regularize videos, also tracking per-video FPS and frame indices for timestamp generation."""
+        results, fps_per_video, durations, frames_indices = [], [], [], []
+        for video in videos:
+            frames: list[ImageObject] = []
+            if _check_video_is_nested_images(video):
+                frames = video
+                fps_per_video.append(kwargs.get("video_fps", 2.0))
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                frames_indices.append(list(range(len(frames))))
+            else:
+                container = av.open(video, "r")
+                video_stream = next(stream for stream in container.streams if stream.type == "video")
+                sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                original_fps = float(video_stream.average_rate)
+                # to correctly calculate timestamps
+                frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])
+                container.seek(0)
+                for frame_idx, frame in enumerate(container.decode(video_stream)):
+                    if frame_idx in sample_indices:
+                        frames.append(frame.to_image())
+
+                if video_stream.duration is None:
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                else:
+                    durations.append(float(video_stream.duration * video_stream.time_base))
+
+            frames = self._regularize_images(frames, **kwargs)["images"]
+            results.append(frames)
+
+        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        image_processor = getattr(processor, "image_processor", None)
+        video_processor = getattr(processor, "video_processor", None)
+        feature_extractor = getattr(processor, "feature_extractor", None)
+        mm_inputs = {}
+
+        if len(images) != 0:
+            regularized = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(regularized, return_tensors="pt"))
+
+        if len(videos) != 0:
+            video_data = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            video_metadata = [
+                {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
+                for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
+            ]
+            mm_inputs.update(
+                video_processor(
+                    videos=video_data["videos"],
+                    video_metadata=video_metadata,
+                    return_tensors="pt",
+                    return_metadata=True,
+                    do_sample_frames=False,
+                )
+            )
+
+        if len(audios) != 0:  # only for gemma4n
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+            )["audios"]
+
+            mm_inputs.update(
+                feature_extractor(
+                    audios,
+                    padding="max_length",
+                    return_tensors="pt",
+                )
+            )
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        messages = deepcopy(messages)
+
+        boi_token: str = getattr(processor, "boi_token")
+        eoi_token: str = getattr(processor, "eoi_token")
+        boa_token: str = getattr(processor, "boa_token")
+        eoa_token: str = getattr(processor, "eoa_token")
+        image_token: str = getattr(processor, "image_token")
+        video_token: str = getattr(processor, "video_token")
+        audio_token: str = getattr(processor, "audio_token")
+
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            num_image_soft_tokens: list[int] = list(
+                mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images))
+            )
+            num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos)))
+            video_metadata = mm_inputs.get("video_metadata", [])
+        else:
+            num_image_soft_tokens = [1] * len(images)
+            num_video_soft_tokens = [1] * len(videos)
+            video_metadata = [None] * len(videos)
+
+        audio_iter = iter(audios)
+        image_iter = iter(num_image_soft_tokens)
+        video_iter = iter(zip(num_video_soft_tokens, video_metadata))
+
+        for message in messages:
+            content = message["content"]
+
+            while IMAGE_PLACEHOLDER in content:
+                n = next(image_iter)
+                content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1)
+
+            while VIDEO_PLACEHOLDER in content:
+                num_soft_tokens_per_frame, metadata = next(video_iter)
+                if self.expand_mm_tokens:
+                    timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
+                    frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
+                    video_str = " ".join(frame_strs)
+                else:
+                    video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
+                content = content.replace(VIDEO_PLACEHOLDER, video_str, 1)
+
+            while AUDIO_PLACEHOLDER in content:
+                current_audio = next(audio_iter)
+                if self.expand_mm_tokens:
+                    num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
+                    audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
+                else:
+                    audio_str = f"{boa_token}{audio_token}{eoa_token}"
+
+                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
+
+            message["content"] = content
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        # Pop metadata keys that must not be passed to the model.
+        for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata",
+                    "_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"):
+            mm_inputs.pop(key, None)
+
+        mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
+
+        return mm_inputs
+
+
 @dataclass
 class InternVLPlugin(BasePlugin):
     @override
```
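
The `mm:ss` timestamps that `process_messages` prepends to each video frame can be sketched in isolation (`format_timestamps` is an illustrative helper name, not part of the codebase):

```python
def format_timestamps(seconds: list[float]) -> list[str]:
    """Render frame times in seconds as the zero-padded mm:ss strings placed before each frame."""
    return [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in seconds]
```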
```diff
@@ -1489,10 +1677,11 @@ class Qwen2VLPlugin(BasePlugin):
 
     @override
     def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
-        results, fps_per_video, durations = [], [], []
+        results, fps_per_video, durations, frames_indices = [], [], [], []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
+                # we assume frames were already sampled from the videos
                 for frame in video:
                     if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                         raise ValueError("Invalid image found in video frames.")
@@ -1500,10 +1689,14 @@ class Qwen2VLPlugin(BasePlugin):
                 frames = video
                 fps_per_video.append(kwargs.get("video_fps", 2.0))
                 durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                frames_indices.append(list(range(len(frames))))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
                 sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
+                original_fps = float(video_stream.average_rate)
+                # for qwen3vl video timestamp calculation
+                frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])  # hack usage when do_sample_frames=False
                 container.seek(0)
                 for frame_idx, frame in enumerate(container.decode(video_stream)):
                     if frame_idx in sample_indices:
@@ -1522,7 +1715,7 @@ class Qwen2VLPlugin(BasePlugin):
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
+        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
 
     @override
     def _get_mm_inputs(
@@ -1637,8 +1830,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_metadata = [
-                {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
-                for video, duration in zip(videos["videos"], videos["durations"])
+                {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
+                for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
             ]
             mm_inputs.update(
                 video_processor(
@@ -1646,6 +1839,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                     video_metadata=video_metadata,
                     fps=getattr(processor, "video_fps", 2.0),
                     return_metadata=True,
+                    do_sample_frames=False,  # avoid changing frames_indices
                 )
             )
             temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
@@ -1677,7 +1871,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
             image_grid_thw = mm_inputs.get("image_grid_thw", [])
             video_grid_thw = mm_inputs.get("video_grid_thw", [])
             num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
-            video_metadata = mm_inputs.get("video_metadata", {})
+            video_metadata = mm_inputs.get("video_metadata", [])
         else:
             image_grid_thw = [None] * len(images)
```
```diff
@@ -2200,8 +2394,9 @@ PLUGINS = {
     "base": BasePlugin,
     "ernie_vl": ErnieVLPlugin,
     "gemma3": Gemma3Plugin,
-    "glm4v": GLM4VPlugin,
     "gemma3n": Gemma3nPlugin,
+    "gemma4": Gemma4Plugin,
+    "glm4v": GLM4VPlugin,
     "intern_vl": InternVLPlugin,
     "kimi_vl": KimiVLPlugin,
     "llama4": Llama4Plugin,
```
```diff
@@ -997,6 +997,55 @@ register_template(
 )
 
 
+register_template(
+    name="gemma4",
+    format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
+    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]),  # default thought signal contained
+    format_observation=StringFormatter(
+        slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
+    ),  # seems inconsistent with the chat template
+    format_tools=ToolFormatter(tool_format="gemma4"),
+    format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<turn|>"],
+    default_system="You are a helpful assistant.",  # important for thinking
+    thought_words=("<|channel>thought\n", "<channel|>"),
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(
+        "gemma4",
+        image_token="<|image|>",
+        video_token="<|video|>",
+    ),
+    template_class=ReasoningTemplate,
+)
+
+
+register_template(
+    name="gemma4n",
+    format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
+    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]),  # default thought signal contained
+    format_observation=StringFormatter(
+        slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
+    ),
+    format_tools=ToolFormatter(tool_format="gemma4"),
+    format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<turn|>"],
+    default_system="You are a helpful assistant.",  # important for thinking
+    thought_words=("<|channel>thought\n", "<channel|>"),
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(
+        "gemma4",
+        image_token="<|image|>",
+        video_token="<|video|>",
+        audio_token="<|audio|>",
+    ),
+    template_class=ReasoningTemplate,
+)
+
+
 register_template(
     name="glm4",
     format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
```
```diff
@@ -209,6 +209,164 @@ class DefaultToolUtils(ToolUtils):
 
         return results
 
 
+class Gemma4ToolUtils(ToolUtils):
+    r"""Gemma-4 tool using template."""
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        def _format_parameters(properties: dict[str, Any]) -> str:
+            parts: list[str] = []
+            for name, schema in properties.items():
+                item_parts: list[str] = []
+                if schema.get("description"):
+                    item_parts.append(f'description:<|"|>{schema["description"]}<|"|>')
+                if schema.get("type"):
+                    item_parts.append(f'type:<|"|>{str(schema["type"]).upper()}<|"|>')
+                parts.append(f"{name}:{{{','.join(item_parts)}}}")
+
+            return ",".join(parts)
+
+        declarations: list[str] = []
+        for tool in tools:
+            function_data = tool.get("function", tool) if tool.get("type") == "function" else tool
+            declaration = (
+                f"declaration:{function_data['name']}"
+                + "{"
+                + f'description:<|"|>{function_data.get("description", "")}<|"|>'
+            )
+
+            params = function_data.get("parameters")
+            if params:
+                param_parts: list[str] = []
+                if params.get("properties"):
+                    param_parts.append(f"properties:{{{_format_parameters(params['properties'])}}}")
+
+                if params.get("required"):
+                    required_text = ",".join(f'<|"|>{item}<|"|>' for item in params["required"])
+                    param_parts.append(f"required:[{required_text}]")
+
+                if params.get("type"):
+                    param_parts.append(f'type:<|"|>{str(params["type"]).upper()}<|"|>')
+
+                declaration += f",parameters:{{{','.join(param_parts)}}}"
+
+            response_declaration = function_data.get("response")
+            if response_declaration:
+                response_parts: list[str] = []
+                if response_declaration.get("description"):
+                    response_parts.append(f'description:<|"|>{response_declaration["description"]}<|"|>')
+
+                response_type = str(response_declaration.get("type", "")).upper()
+
+                if response_type == "OBJECT":
+                    response_parts.append(f'type:<|"|>{response_type}<|"|>')
+
+                declaration += f",response:{{{','.join(response_parts)}}}"
+
+            declarations.append(declaration + "}")
+
+        return "\n".join(declarations)
+
+    @override
+    @staticmethod
+    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
+        regex = re.compile(r"<\|tool_call\>call:([^{\s]+)\{(.*?)\}<tool_call\|>", re.DOTALL)
+        matches = re.findall(regex, content)
+        if not matches:
+            return content
+
+        def _parse_arguments(arg_text: str) -> Any:
+            text = arg_text.strip()
+            if not text:
+                return {}
+
+            # `function_formatter` writes dict arguments as `k:v,...` inside `{...}`.
+            # The extractor captures only the inner text, so re-wrap it to parse as a JSON object.
+            object_like_text = "{" + text + "}"
+            # Convert Gemma string markers (<|"|>value<|"|>) to valid JSON strings.
+            normalized = re.sub(
+                r"<\|\"\|\>(.*?)<\|\"\|\>",
+                lambda m: json.dumps(m.group(1), ensure_ascii=False),
+                object_like_text,
+                flags=re.DOTALL,
+            )
+            # Quote unquoted object keys so the payload can be parsed by json.loads.
+            normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized)
+            try:
+                return json.loads(normalized)
+            except json.JSONDecodeError:
+                pass
+
+            try:
+                return json.loads(text)
+            except json.JSONDecodeError:
+                return text
+
+        results: list[FunctionCall] = []
+        for name, arg_block in matches:
+            parsed_arguments = _parse_arguments(arg_block)
+            if isinstance(parsed_arguments, str):
+                arguments = parsed_arguments
+            else:
+                arguments = json.dumps(parsed_arguments, ensure_ascii=False)
+            results.append(FunctionCall(name.strip(), arguments))
+
+        return results
```
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||||
|
def _format_argument(argument: Any, escape_keys: bool = True) -> str:
|
||||||
|
if isinstance(argument, str):
|
||||||
|
return f'<|"|>{argument}<|"|>'
|
||||||
|
|
||||||
|
if isinstance(argument, bool):
|
||||||
|
return "true" if argument else "false"
|
||||||
|
|
||||||
|
if isinstance(argument, dict):
|
||||||
|
items: list[str] = []
|
||||||
|
for key in sorted(argument.keys()):
|
||||||
|
formatted_key = f'<|"|>{key}<|"|>' if escape_keys else str(key)
|
||||||
|
formatted_value = _format_argument(argument[key], escape_keys=escape_keys)
|
||||||
|
items.append(f"{formatted_key}:{formatted_value}")
|
||||||
|
return "{" + ",".join(items) + "}"
|
||||||
|
|
||||||
|
if isinstance(argument, (list, tuple)):
|
||||||
|
return "[" + ",".join(_format_argument(item, escape_keys=escape_keys) for item in argument) + "]"
|
||||||
|
|
||||||
|
if argument is None:
|
||||||
|
return "null"
|
||||||
|
|
||||||
|
return str(argument)
|
||||||
|
|
||||||
|
function_texts: list[str] = []
|
||||||
|
for function in functions:
|
||||||
|
name = function.name
|
||||||
|
raw_arguments = function.arguments
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_arguments = json.loads(raw_arguments)
|
||||||
|
except (TypeError, json.JSONDecodeError):
|
||||||
|
parsed_arguments = raw_arguments
|
||||||
|
|
||||||
|
call_text = f"<|tool_call>call:{name}" + "{"
|
||||||
|
if isinstance(parsed_arguments, dict):
|
||||||
|
args_text = []
|
||||||
|
for key in sorted(parsed_arguments.keys()):
|
||||||
|
value_text = _format_argument(parsed_arguments[key], escape_keys=False)
|
||||||
|
args_text.append(f"{key}:{value_text}")
|
||||||
|
|
||||||
|
call_text += ",".join(args_text)
|
||||||
|
elif isinstance(parsed_arguments, str):
|
||||||
|
call_text += parsed_arguments
|
||||||
|
else:
|
||||||
|
call_text += _format_argument(parsed_arguments, escape_keys=False)
|
||||||
|
|
||||||
|
call_text += "}<tool_call|>"
|
||||||
|
function_texts.append(call_text)
|
||||||
|
|
||||||
|
return "".join(function_texts)
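The argument normalization performed inside `tool_extractor` can be sketched standalone. This minimal version (the helper name is ours, not the project's) shows how Gemma string markers `<|"|>…<|"|>` become JSON strings and how bare keys get quoted before `json.loads`:

```python
import json
import re

def normalize_gemma_args(arg_text: str) -> dict:
    """Parse Gemma-style call arguments like `city:<|"|>Paris<|"|>,days:3` into a dict."""
    object_like = "{" + arg_text.strip() + "}"
    # Replace <|"|>value<|"|> markers with proper JSON string literals.
    normalized = re.sub(
        r'<\|"\|>(.*?)<\|"\|>',
        lambda m: json.dumps(m.group(1), ensure_ascii=False),
        object_like,
        flags=re.DOTALL,
    )
    # Quote bare identifiers used as keys so json.loads accepts the payload.
    normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized)
    return json.loads(normalized)

print(normalize_gemma_args('city:<|"|>Paris<|"|>,days:3'))  # {'city': 'Paris', 'days': 3}
```

The real extractor additionally falls back to parsing the raw text, and returns it verbatim when nothing parses.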


class GLM4ToolUtils(ToolUtils):
    r"""GLM-4 tool using template."""
@@ -361,6 +519,8 @@ class MiniMaxM2ToolUtils(ToolUtils):
            prompt += "\n</invoke>"
            function_texts.append(prompt)

        return "\n".join(function_texts)

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
@@ -721,6 +881,7 @@ class LFM2ToolUtils(ToolUtils):

TOOLS = {
    "default": DefaultToolUtils(),
+    "gemma4": Gemma4ToolUtils(),
    "glm4": GLM4ToolUtils(),
    "llama3": Llama3ToolUtils(),
    "lfm2": LFM2ToolUtils(),
@@ -865,6 +865,34 @@ register_model_group(
)


+register_model_group(
+    models={
+        "Gemma-4-26B-A4B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-26B-A4B-it",
+        },
+        "Gemma-4-31B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-31B-it",
+        },
+    },
+    template="gemma4",
+    multimodal=True,
+)
+
+
+register_model_group(
+    models={
+        "Gemma-4-E2B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-E2B-it",
+        },
+        "Gemma-4-E4B-Thinking": {
+            DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
+        },
+    },
+    template="gemma4n",
+    multimodal=True,
+)
+
+
register_model_group(
    models={
        "GLM-4-9B": {
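`register_model_group` ties each display name to its download sources and chat template. A toy sketch of the same registry pattern (simplified signature; the real helper in LlamaFactory carries richer metadata):

```python
# Minimal stand-in for the model-group registry; field names here are assumptions.
MODEL_GROUPS: dict[str, dict] = {}

def register_model_group(models: dict, template: str, multimodal: bool = False) -> None:
    for name, sources in models.items():
        MODEL_GROUPS[name] = {"sources": sources, "template": template, "multimodal": multimodal}

register_model_group(
    models={"Gemma-4-E2B-Thinking": {"default": "google/gemma-4-E2B-it"}},
    template="gemma4n",
    multimodal=True,
)
print(MODEL_GROUPS["Gemma-4-E2B-Thinking"]["template"])  # gemma4n
```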
@@ -70,6 +70,10 @@ def is_matplotlib_available():
    return _is_package_available("matplotlib")


+def is_hyper_parallel_available():
+    return _is_package_available("hyper_parallel")
+
+
def is_mcore_adapter_available():
    return _is_package_available("mcore_adapter")
@@ -482,6 +482,24 @@ class FinetuningArguments(
            )
        },
    )
+    use_hyper_parallel: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to use HyperParallel distributed training backend (FSDP/TP). "
+                "Only supported for the 'sft' stage with full fine-tuning."
+            )
+        },
+    )
+    hyper_parallel_args: str | None = field(
+        default=None,
+        metadata={
+            "help": (
+                "Path to a JSON file containing HyperParallel strategy arguments "
+                "(e.g., tp_size, param_dtype). Used when use_hyper_parallel=True."
+            )
+        },
+    )
    use_muon: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the Muon optimizer."},
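Per the help text, `hyper_parallel_args` points at a JSON strategy file. A hypothetical example of writing and reading such a file (only `tp_size` and `param_dtype` come from the help text above; the exact accepted schema is not shown in this diff):

```python
import json
import tempfile
from pathlib import Path

# Assumed strategy payload for hyper_parallel_args.
strategy = {"tp_size": 2, "param_dtype": "bfloat16"}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "hp_args.json"
    path.write_text(json.dumps(strategy))
    loaded = json.loads(path.read_text())

print(loaded["tp_size"])  # 2
```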
@@ -125,7 +125,7 @@ def _setup_freeze_tuning(

    model_type = getattr(model.config, "model_type", None)
    if not finetuning_args.freeze_multi_modal_projector and model_type in COMPOSITE_MODELS:
-        trainable_layers.append(COMPOSITE_MODELS[model_type].projector_key)
+        trainable_layers.extend(COMPOSITE_MODELS[model_type].projector_keys)

    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
    for name, param in model.named_parameters():
@@ -45,7 +45,7 @@ def apply_liger_kernel(
        from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel
    elif model_type == "gemma3_text":
        from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel
-    elif model_type == "glm4":
+    elif model_type in ["glm", "glm4"]:  # for glm4-9b and glm4-32b, respectively
        from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel
    elif model_type == "glm4v":
        from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel
@@ -35,7 +35,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
        forbidden_modules.add("output")

    if model_type in COMPOSITE_MODELS:
-        forbidden_modules.add(COMPOSITE_MODELS[model_type].projector_key)
+        forbidden_modules.update(COMPOSITE_MODELS[model_type].projector_keys)

    if freeze_vision_tower and model_type in COMPOSITE_MODELS:
        forbidden_modules.update(COMPOSITE_MODELS[model_type].vision_model_keys)
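The `add()` to `update()` switch follows directly from the refactor: `set.add` inserts one element, while `set.update` merges every element of an iterable, which is what a `projector_keys` list needs. A quick illustration (key strings are examples, not the real module paths at every call site):

```python
forbidden_modules = {"output"}
forbidden_modules.update(["model.embed_vision", "model.embed_audio"])  # merge a list of keys
forbidden_modules.add("resampler")  # insert a single key
print(sorted(forbidden_modules))
```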
@@ -147,6 +147,11 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:

        _set_z3_leaf_modules(model, [Qwen3NextSparseMoeBlock])

+    if model_type == "qwen3_5_moe":
+        from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock
+
+        _set_z3_leaf_modules(model, [Qwen3_5MoeSparseMoeBlock])


def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
    if not is_trainable or not model_args.moe_aux_loss_coef:
@@ -39,16 +39,26 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
@dataclass
class CompositeModel:
    model_type: str
-    projector_key: str
+    projector_keys: list[str]
    vision_model_keys: list[str]
    language_model_keys: list[str]
    lora_conflict_keys: list[str]

-    def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
-        for key in self.projector_key.split("."):
-            module = getattr(module, key)
-
-        return module
+    def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
+        mm_projectors: list[torch.nn.Module] = []
+        for projector_key in self.projector_keys:
+            project_module = module
+            for key in projector_key.split("."):
+                project_module = getattr(project_module, key, None)
+                if project_module is None:  # e.g., the larger gemma4 variants have no embed_audio
+                    logger.warning_rank0(f"Projector key {projector_key} not found in module {module.__class__.__name__}.")
+                    break
+
+            if project_module is not None:
+                mm_projectors.append(project_module)
+
+        return mm_projectors


COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}
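The dotted-key walk in `get_projectors` can be demonstrated without torch. This sketch (toy objects, not real model modules) shows the `getattr` traversal with a `None` default tolerating absent hops such as `embed_audio`:

```python
from types import SimpleNamespace

# Toy stand-in for a composite model's module tree.
root = SimpleNamespace(model=SimpleNamespace(embed_vision="vision-projector"))

def resolve(module, dotted_key):
    # Walk `a.b.c` with getattr; return None as soon as any hop is missing.
    for key in dotted_key.split("."):
        module = getattr(module, key, None)
        if module is None:
            return None
    return module

print(resolve(root, "model.embed_vision"))  # vision-projector
print(resolve(root, "model.embed_audio"))   # None
```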
@@ -56,7 +66,7 @@ COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}

def _register_composite_model(
    model_type: str,
-    projector_key: Optional[str] = None,
+    projector_keys: list[str] | None = None,
    vision_model_keys: Optional[list[str]] = None,
    language_model_keys: Optional[list[str]] = None,
    lora_conflict_keys: Optional[list[str]] = None,
@@ -65,7 +75,7 @@ def _register_composite_model(

    Args:
        model_type: model type
-        projector_key: multi_modal_projector
+        projector_keys: multi_modal_projector
        vision_model_keys: vision_tower
        language_model_keys: language_model
        lora_conflict_keys: None
@@ -73,7 +83,7 @@ def _register_composite_model(
    """
    COMPOSITE_MODELS[model_type] = CompositeModel(
        model_type=model_type,
-        projector_key=projector_key or "multi_modal_projector",
+        projector_keys=projector_keys or ["multi_modal_projector"],
        vision_model_keys=vision_model_keys or ["vision_tower"],
        language_model_keys=language_model_keys or ["language_model", "lm_head"],
        lora_conflict_keys=lora_conflict_keys or [],
@@ -136,12 +146,16 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
    if getattr(model, "quantization_method", None):
        model_type = getattr(model.config, "model_type", None)
        if model_type in COMPOSITE_MODELS:
-            mm_projector = COMPOSITE_MODELS[model_type].get_projector(model)
+            mm_projectors = COMPOSITE_MODELS[model_type].get_projectors(model)
        else:
            return

-        logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
-        mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
+        logger.info_rank0(
+            f"Casting multimodal projector outputs in {model_args.compute_dtype}: "
+            f"{COMPOSITE_MODELS[model_type].projector_keys}."
+        )
+        for mm_projector in mm_projectors:
+            mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


def configure_visual_model(config: "PretrainedConfig") -> None:
@@ -166,9 +180,9 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
        forbidden_modules.update(vision_model_keys)

    if finetuning_args.freeze_multi_modal_projector:
-        projector_key = COMPOSITE_MODELS[model_type].projector_key
-        logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
-        forbidden_modules.add(projector_key)
+        projector_keys = COMPOSITE_MODELS[model_type].projector_keys
+        logger.info_rank0(f"Set multimodal projector not trainable: {projector_keys}.")
+        forbidden_modules.update(projector_keys)

    if finetuning_args.freeze_language_model:
        language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
@@ -200,7 +214,7 @@ def patch_target_modules(

_register_composite_model(
    model_type="dots_ocr",
-    projector_key="vision_tower.merger",
+    projector_keys=["vision_tower.merger"],
    vision_model_keys=["vision_tower"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["merger"],
@@ -219,10 +233,18 @@ _register_composite_model(
)


+_register_composite_model(
+    model_type="gemma4",
+    projector_keys=["model.embed_vision", "model.embed_audio"],
+    vision_model_keys=["vision_tower", "audio_tower"],
+    lora_conflict_keys=["per_layer_projection_norm"],
+)
+
+
# copied from qwen2vl
_register_composite_model(
    model_type="glm4v",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -231,7 +253,7 @@ _register_composite_model(

_register_composite_model(
    model_type="glm4v_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -240,7 +262,7 @@ _register_composite_model(

_register_composite_model(
    model_type="glm_ocr",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -257,7 +279,7 @@ _register_composite_model(

_register_composite_model(
    model_type="Keye",
-    projector_key="mlp_AR",
+    projector_keys=["mlp_AR"],
    vision_model_keys=["visual.vision_model.patch_embedding", "visual.vision_model.encoder"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["patch_embedding"],
@@ -292,7 +314,7 @@ _register_composite_model(

_register_composite_model(
    model_type="minicpmv",
-    projector_key="resampler",
+    projector_keys=["resampler"],
    vision_model_keys=["vpm"],
    language_model_keys=["llm"],
)
@@ -300,7 +322,7 @@ _register_composite_model(

_register_composite_model(
    model_type="minicpmo",
-    projector_key="resampler",
+    projector_keys=["resampler"],
    vision_model_keys=["vpm", "apm", "audio_avg_pooler", "audio_projection_layer", "tts"],
    language_model_keys=["llm"],
    lora_conflict_keys=["audio_projection_layer"],
@@ -309,7 +331,7 @@ _register_composite_model(

_register_composite_model(
    model_type="mistral3",
-    projector_key="model.multi_modal_projector",
+    projector_keys=["model.multi_modal_projector"],
)


@@ -332,7 +354,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen2_5_omni_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
    vision_model_keys=["visual.patch_embed", "visual.blocks", "audio_tower"],
    language_model_keys=["model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -341,7 +363,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen2_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -350,7 +372,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen2_5_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
    vision_model_keys=["visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -359,7 +381,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen3_vl",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -368,7 +390,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen3_vl_moe",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger"],
    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks", "visual.deepstack_merger_list"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -377,7 +399,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen3_omni_moe_thinker",
-    projector_key="visual.merger",
+    projector_keys=["visual.merger", "audio_tower.proj"],
    vision_model_keys=[
        "visual.pos_embed",
        "visual.patch_embed",
@@ -392,7 +414,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen3_5",
-    projector_key="model.visual.merger",
+    projector_keys=["model.visual.merger"],
    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
@@ -401,7 +423,7 @@ _register_composite_model(

_register_composite_model(
    model_type="qwen3_5_moe",
-    projector_key="model.visual.merger",
+    projector_keys=["model.visual.merger"],
    vision_model_keys=["visual.pos_embed", "visual.patch_embed", "visual.blocks"],
    language_model_keys=["language_model", "lm_head"],
    lora_conflict_keys=["patch_embed"],
src/llamafactory/train/hyper_parallel/__init__.py (new file, 18 lines)
@@ -0,0 +1,18 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .workflow import run_sft


__all__ = ["run_sft"]
src/llamafactory/train/hyper_parallel/workflow.py (new file, 183 lines)
@@ -0,0 +1,183 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps
from ...extras.packages import is_hyper_parallel_available, is_transformers_version_greater_than
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..callbacks import SaveProcessorCallback
from ..sft.metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
from ..trainer_utils import asft_loss_func, create_modelcard_and_push, create_ref_model, dft_loss_func, eaft_loss_func


if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback

    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)
def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
):
    if not is_hyper_parallel_available():
        raise ImportError(
            "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
        )

    from hyper_parallel.integration.llamafactory import (  # pylint: disable=C0415
        HyperParallelArguments,
        HyperParallelTrainer,
    )

    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)

    ref_model = None
    if finetuning_args.use_asft_loss:
        ref_model = create_ref_model(model_args, finetuning_args)

    data_collator = SFTDataCollatorWith4DAttentionMask(
        template=template,
        model=model if not training_args.predict_with_generate else None,
        pad_to_multiple_of=8 if training_args.do_train else None,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
        block_diag_attn=model_args.block_diag_attn,
        attn_implementation=getattr(model.config, "_attn_implementation", None),
        compute_dtype=model_args.compute_dtype,
        **tokenizer_module,
    )

    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
        metric_module["compute_metrics"] = ComputeSimilarity(tokenizer=tokenizer)
    elif finetuning_args.compute_accuracy:
        metric_module["compute_metrics"] = ComputeAccuracy()
        metric_module["preprocess_logits_for_metrics"] = eval_logit_processor

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
    if is_transformers_version_greater_than("4.58.0"):
        extra_ids = getattr(tokenizer, "additional_special_tokens_ids", None)
        if not isinstance(extra_ids, list):
            extra_special_tokens = getattr(tokenizer, "_extra_special_tokens", [])
            string_tokens = [str(t) for t in extra_special_tokens]
            extra_ids = tokenizer.convert_tokens_to_ids(string_tokens)
        all_eos_ids = [tokenizer.eos_token_id] + [i for i in extra_ids if i != -1]
        gen_kwargs["eos_token_id"] = list(dict.fromkeys(all_eos_ids))
    else:
        gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
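The `eos_token_id` construction above leans on `list(dict.fromkeys(...))`, which deduplicates while preserving first-seen order. A standalone illustration (the token ids here are made up):

```python
eos_token_id = 2
extra_ids = [2, 7, 7, -1, 9]  # -1 marks tokens the tokenizer could not resolve
all_eos_ids = [eos_token_id] + [i for i in extra_ids if i != -1]
print(list(dict.fromkeys(all_eos_ids)))  # [2, 7, 9]
```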

    hp_args = HyperParallelArguments.from_finetuning_args(finetuning_args)

    callbacks = list(callbacks or [])
    processor = tokenizer_module.get("processor")
    if processor is not None:
        callbacks.append(SaveProcessorCallback(processor))

    compute_loss_func = None
    if finetuning_args.use_dft_loss:
        compute_loss_func = dft_loss_func
    elif finetuning_args.use_eaft_loss:
        compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(  # noqa: E731
            outputs, labels, num_items_in_batch, finetuning_args.eaft_alpha
        )
    elif finetuning_args.use_asft_loss:
        from functools import partial

        compute_loss_func = partial(asft_loss_func, asft_alpha=finetuning_args.asft_alpha)
|
||||||
|
|
||||||
|
trainer = HyperParallelTrainer(
|
||||||
|
hp_args=hp_args,
|
||||||
|
model=model,
|
||||||
|
args=training_args,
|
||||||
|
finetuning_args=finetuning_args,
|
||||||
|
data_collator=data_collator,
|
||||||
|
callbacks=callbacks,
|
||||||
|
gen_kwargs=gen_kwargs,
|
||||||
|
ref_model=ref_model,
|
||||||
|
compute_loss_func=compute_loss_func,
|
||||||
|
**dataset_module,
|
||||||
|
**tokenizer_module,
|
||||||
|
**metric_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
if finetuning_args.use_badam:
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
||||||
|
|
||||||
|
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
||||||
|
trainer.add_callback(BAdamCallback)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
if training_args.do_train:
|
||||||
|
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||||
|
trainer.save_model()
|
||||||
|
if finetuning_args.include_effective_tokens_per_second:
|
||||||
|
train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
|
||||||
|
dataset_module["train_dataset"], train_result.metrics, stage="sft"
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.log_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_metrics("train", train_result.metrics)
|
||||||
|
trainer.save_state()
|
||||||
|
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||||
|
keys = ["loss"]
|
||||||
|
if isinstance(dataset_module.get("eval_dataset"), dict):
|
||||||
|
keys += sum(
|
||||||
|
[[f"eval_{key}_loss", f"eval_{key}_accuracy"] for key in dataset_module["eval_dataset"].keys()],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
keys += ["eval_loss", "eval_accuracy"]
|
||||||
|
|
||||||
|
plot_loss(training_args.output_dir, keys=keys)
|
||||||
|
|
||||||
|
if training_args.predict_with_generate:
|
||||||
|
tokenizer.padding_side = "left"
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
if training_args.do_eval:
|
||||||
|
metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
|
||||||
|
trainer.log_metrics("eval", metrics)
|
||||||
|
trainer.save_metrics("eval", metrics)
|
||||||
|
|
||||||
|
# Predict
|
||||||
|
if training_args.do_predict:
|
||||||
|
logger.warning_rank0_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
|
||||||
|
predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
|
||||||
|
trainer.log_metrics("predict", predict_results.metrics)
|
||||||
|
trainer.save_metrics("predict", predict_results.metrics)
|
||||||
|
trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
|
||||||
|
|
||||||
|
# Create model card
|
||||||
|
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||||
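As an aside on the `eos_token_id` handling above: `dict.fromkeys` dedupes while preserving first-occurrence order, which a plain `set` would not guarantee. A minimal standalone sketch with made-up token ids (not the real tokenizer values):

```python
# Hypothetical token ids: the eos id first, then extra special-token ids
# that may contain duplicates and unresolved entries (-1).
eos_token_id = 151643
extra_ids = [151644, 151643, 151645, -1, 151644]

# Drop unresolved ids and dedupe while preserving order.
all_eos_ids = [eos_token_id] + [i for i in extra_ids if i != -1]
unique_eos_ids = list(dict.fromkeys(all_eos_ids))
print(unique_eos_ids)  # [151643, 151644, 151645]
```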
@@ -24,7 +24,12 @@ from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
 from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.misc import find_available_port, get_device_name, get_torch_device, infer_optim_dtype
-from ..extras.packages import is_mcore_adapter_available, is_ray_available, is_transformers_version_greater_than
+from ..extras.packages import (
+    is_hyper_parallel_available,
+    is_mcore_adapter_available,
+    is_ray_available,
+    is_transformers_version_greater_than,
+)
 from ..hparams import RayArguments, get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
@@ -71,7 +76,16 @@ def _training_function(config: dict[str, Any]) -> None:
 
     callbacks.append(ReporterCallback(model_args, data_args, finetuning_args, generating_args))  # add to last
 
-    if finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
+    if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
+        if not is_hyper_parallel_available():
+            raise ImportError(
+                "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
+            )
+        from .hyper_parallel import run_sft as run_sft_hp
+
+        run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+
+    elif finetuning_args.stage in ["pt", "sft", "dpo"] and finetuning_args.use_mca:
         if not is_mcore_adapter_available():
             raise ImportError("mcore_adapter is not installed. Please install it with `pip install mcore-adapter`.")
         if finetuning_args.stage == "pt":
@@ -85,6 +85,10 @@ class TrainingArguments:
         default=42,
         metadata={"help": "Random seed that will be set at the beginning of training."},
     )
+    logging_steps: int = field(
+        default=1,
+        metadata={"help": "Log metrics every N optimizer steps."},
+    )
 
     def __post_init__(self) -> None:
         self.dist_config = get_plugin_config(self.dist_config)
@@ -36,6 +36,12 @@ from ..accelerator.helper import ReduceOp
 from ..accelerator.interface import Dim, DistributedInterface
 from ..config import TrainingArguments
 from ..utils import logging
+from ..utils.callbacks import (
+    CallbackHandler,
+    LoggingCallback,
+    TrainerCallback,
+    TrainerState,
+)
 from ..utils.helper import compute_valid_tokens
 from ..utils.types import BatchInput, HFModel, ModelOutput, Tensor, TorchDataset
 from .utils.batching import BatchGenerator
@@ -52,6 +58,7 @@ class BaseTrainer:
         model: HFModel,
         renderer: Renderer,
         train_dataset: TorchDataset,
+        callbacks: list[TrainerCallback] | None = None,
     ) -> None:
         self.args = args
         self.model = model
@@ -64,6 +71,7 @@ class BaseTrainer:
         # cached variables
         self.device = DistributedInterface().current_device
         self.dp_size = DistributedInterface().get_world_size(Dim.DP)
+        self.cp_size = DistributedInterface().get_world_size(Dim.CP)
         self.model_input_names = self.renderer.processor.model_input_names
 
         self._create_batch_generator()
@@ -99,6 +107,29 @@ class BaseTrainer:
         self._init_optimizer()
         self._init_lr_scheduler()
+
+        # Callbacks
+        self.callback_handler = CallbackHandler([LoggingCallback()], trainer=self)
+        for cb in callbacks or []:
+            self.callback_handler.add_callback(cb)
+
+        # TrainerState tracks progress across the full run.
+        self.state = TrainerState(num_training_steps=self.num_training_steps)
+
+        if self.args.dist_config is not None and self.args.dist_config.get("cp_size", 1) > 1:
+            # qwen3.5 is not supported yet due to its different attention implementation.
+            if model.config.model_type == "qwen3_5":
+                raise RuntimeError(
+                    "Sequence parallelism is not yet supported for the qwen3.5 model due to its different attention implementation."
+                )
+            from ..plugins.model_plugins.parallelization.sequence_parallel import SequenceParallelModelPlugin
+
+            if model.config._attn_implementation != "flash_attention_2":
+                logger.warning_rank0(
+                    "Sequence parallelism is optimized for flash attention only. Replacing the attention implementation with flash_attention_2."
+                )
+                model.config._attn_implementation = "flash_attention_2"
+            SequenceParallelModelPlugin(self.args.dist_config.get("cp_mode", "ulysses"))(model, self.args.dist_config)
 
     def _create_batch_generator(self) -> None:
         self.train_batch_generator = BatchGenerator(
             dataset=self.train_dataset,
@@ -157,7 +188,7 @@
         """
         batch_size, _ = batch["labels"].shape
         model_inputs = {
-            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if k in self.model_input_names
+            k: v.to(self.device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)
         }
         labels = batch["labels"].to(self.device, non_blocking=True)
         outputs: ModelOutput = model(**model_inputs)
@@ -174,16 +205,31 @@
     def fit(self) -> None:
         """Train the model."""
         self.model.train()
+        self.callback_handler.on_train_begin(self.args, self.state)
         for epoch in range(self.args.num_train_epochs):
+            self.state.epoch = epoch
             self.train_batch_generator.set_epoch(epoch)
+            self.callback_handler.on_epoch_begin(self.args, self.state)
+
             for micro_batches in self.train_batch_generator:
                 self.global_step += 1
+
+                self.state.global_step = self.global_step
+                self.callback_handler.on_step_begin(self.args, self.state)
+
                 step_loss = 0
                 step_valid_tokens = compute_valid_tokens(micro_batches)
                 step_valid_tokens = DistributedInterface().all_reduce(step_valid_tokens, op=ReduceOp.SUM)
                 num_micro = len(micro_batches)
                 for i, micro_batch in enumerate(micro_batches):
-                    loss = self.compute_loss(micro_batch)
+                    if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
+                        from ..plugins.model_plugins.parallelization.sequence_parallel import (
+                            SequenceParallelLossPlugin,
+                        )
+
+                        loss = SequenceParallelLossPlugin("sequence_parallel_loss")(self.model, micro_batch)
+                    else:
+                        loss = self.compute_loss(micro_batch)
                     mini_step_valid_tokens = compute_valid_tokens([micro_batch])
                     # fsdp uses mean reduction so we need to scale the loss by dp_size
                     loss = loss * mini_step_valid_tokens * self.dp_size / (step_valid_tokens + 1e-6)
@@ -200,7 +246,24 @@
                     # deepspeed: engine.step() already ran inside backward at the sync boundary
                     grad_norm = self._deepspeed_engine.get_grad_norm()
                 else:
-                    grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm).item()
+                    if self.args.dist_config and self.args.dist_config.get("cp_size", 1) > 1:
+                        from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm
+
+                        parameters = self.model.parameters()
+                        if isinstance(parameters, torch.Tensor):
+                            parameters = [parameters]
+                        else:
+                            parameters = list(parameters)
+                        grads = [p.grad for p in parameters if p.grad is not None]
+                        grad_norm = _get_total_norm(grads)
+                        grad_norm = grad_norm.to(self.device)
+                        _clip_grads_with_norm_(parameters, self.args.max_grad_norm, grad_norm)
+                        if isinstance(grad_norm, torch.distributed._tensor.DTensor):
+                            grad_norm = grad_norm.full_tensor().item()
+                    else:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(), self.args.max_grad_norm
+                        ).item()
 
                 # isfinite(): argument 'input' (position 1) must be Tensor, not float
                 if not torch.isfinite(torch.tensor(grad_norm)):  # type: ignore # pyright: ignore [reportUnknownReturnType]
@@ -213,14 +276,41 @@
 
                 step_loss, grad_norm = DistributedInterface().all_reduce([step_loss, grad_norm])
                 DistributedInterface().sync()
-                if DistributedInterface().get_rank() == 0:
-                    print(f"Epoch {epoch}, Step {self.global_step}, Loss: {step_loss:.4f}, Grad Norm: {grad_norm:.4f}")
+                # Update state with step metrics
+                current_lr = (
+                    self.lr_scheduler.get_last_lr()[0]
+                    if hasattr(self.lr_scheduler, "get_last_lr")
+                    else self.args.learning_rate
+                )
+                self.state.loss = step_loss
+                self.state.grad_norm = grad_norm
+                self.state.learning_rate = current_lr
+
+                self.callback_handler.on_step_end(self.args, self.state)
+
+                # Logging: trainer decides when to log
+                if self.global_step % self.args.logging_steps == 0:
+                    logs = {
+                        "epoch": epoch,
+                        "step": self.global_step,
+                        "loss": step_loss,
+                        "grad_norm": grad_norm,
+                        "learning_rate": current_lr,
+                    }
+                    self.callback_handler.on_log(self.args, self.state, logs)
+
                 # Check if max_steps is reached
                 if self.global_step >= self.num_training_steps:
                     logger.info_rank0(f"Reached max_steps ({self.num_training_steps}), stopping training.")
+                    self.callback_handler.on_epoch_end(self.args, self.state)
+                    self.callback_handler.on_train_end(self.args, self.state)
                     return
+
+            self.callback_handler.on_epoch_end(self.args, self.state)
+
+        self.callback_handler.on_train_end(self.args, self.state)
 
     def save_model(self) -> None:
         """Save the model."""
         if self.args.dist_config is not None and self.args.dist_config.name in ("deepspeed", "fsdp2"):
@@ -234,3 +324,5 @@
         model_to_save.save_pretrained(self.args.output_dir, max_shard_size="4GB")
         self.renderer.processor.save_pretrained(self.args.output_dir, max_shard_size="4GB")
         logger.info_rank0(f"Model saved to {self.args.output_dir}")
+
+        self.callback_handler.on_save(self.args, self.state)
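The context-parallel clipping branch above mirrors what `clip_grad_norm_` does in one call: first compute a single global norm over all gradients, then scale every gradient by the same coefficient when the norm exceeds the threshold. A standalone sketch of that two-step logic with plain Python lists (L2 norm assumed; no torch, purely illustrative):

```python
import math

def clip_grads_with_norm(grads, max_norm):
    """Two-step clipping: compute the global L2 norm over all gradients,
    then scale each gradient by max_norm / total_norm if it exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for flat in grads for g in flat))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        grads = [[g * clip_coef for g in flat] for flat in grads]
    return grads, total_norm

grads = [[3.0, 4.0], [0.0, 12.0]]  # global norm = sqrt(9 + 16 + 144) = 13.0
clipped, norm = clip_grads_with_norm(grads, max_norm=1.0)
print(norm)  # 13.0
# After clipping, the global norm is (approximately) max_norm.
new_norm = math.sqrt(sum(g * g for flat in clipped for g in flat))
print(round(new_norm, 4))  # 1.0
```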
@@ -140,6 +140,9 @@ class ModelEngine:
             **init_kwargs,
         )
+
+        init_mode = self.args.init_config.name if self.args.init_config is not None else "init_on_default"
+        model._init_mode = init_mode
 
         if self.args.peft_config is None:
             if self.is_train:
                 logger.info_rank0("Fine-tuning mode: full tuning")
@@ -147,6 +150,9 @@
             else:
                 logger.info_rank0("Inference the original model")
         else:
+            if self.args.peft_config.name == "lora" and init_mode == "init_on_meta":
+                raise ValueError("Currently the LoRA stage does not support loading the model on the meta device.")
+
             from ..plugins.model_plugins.peft import PeftPlugin
 
             model = PeftPlugin(self.args.peft_config.name)(model, self.args.peft_config, self.is_train)
@@ -146,6 +146,8 @@ class Renderer:
         for sample in samples:
             if "messages" in sample:
                 model_input = self.render_messages(sample["messages"], sample.get("tools"))
+                if "position_ids" not in model_input:
+                    model_input["position_ids"] = list(range(1, len(model_input["input_ids"]) + 1))
             elif "chosen_messages" in sample and "rejected_messages" in sample:
                 chosen_input = self.render_messages(sample["chosen_messages"], sample.get("tools"))
                 rejected_input = self.render_messages(sample["rejected_messages"], sample.get("tools"))
@@ -0,0 +1,59 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
#
# This code is inspired by Bytedance's verl library.
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch
import torch.distributed as dist
from torch import Tensor


def all_to_all_tensor(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: Optional[dist.ProcessGroup] = None,
):
    seq_world_size = dist.get_world_size(group)
    input_list = [t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)]
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class SeqAllToAll4D(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        return all_to_all_tensor(local_input, scatter_dim, gather_dim, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> tuple[None, Tensor, None, None]:
        return (
            None,
            all_to_all_tensor(grad_output[0], ctx.gather_dim, ctx.scatter_dim, ctx.group),
            None,
            None,
        )
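For intuition, the `dist.all_to_all` exchange in `all_to_all_tensor` has a simple data-movement semantics: each rank splits its local tensor into `world_size` chunks along `scatter_dim`, rank `r` receives chunk `r` from every peer, and the received chunks are concatenated along `gather_dim`. A single-process sketch with nested lists standing in for tensor chunks (illustrative only; the names are made up, this is not the distributed code path):

```python
def simulated_all_to_all(per_rank_inputs):
    """Single-process stand-in for dist.all_to_all: rank r's output
    collects chunk r from every rank, in rank order."""
    world = len(per_rank_inputs)
    return [[per_rank_inputs[src][dst] for src in range(world)] for dst in range(world)]

# Two ranks, each with its local data pre-split into two chunks (e.g. by head).
rank0 = ["q0_head0", "q0_head1"]  # rank 0's sequence shard
rank1 = ["q1_head0", "q1_head1"]  # rank 1's sequence shard

outputs = simulated_all_to_all([rank0, rank1])
# After the exchange, rank 0 holds head 0 for the full sequence,
# rank 1 holds head 1 for the full sequence.
print(outputs[0])  # ['q0_head0', 'q1_head0']
print(outputs[1])  # ['q0_head1', 'q1_head1']
```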
@@ -0,0 +1,199 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from functools import partial

import torch
import torch.distributed as dist
import torch.nn.functional as F
import transformers

from ....accelerator.interface import Dim, DistributedInterface
from ....utils import logging
from ....utils.plugin import BasePlugin
from ....utils.types import ModelOutput
from .ulysses import (
    UlyssesAttention,
    get_ulysses_sequence_parallel_group,
    get_ulysses_sequence_parallel_rank,
    get_ulysses_sequence_parallel_world_size,
    set_ulysses_sequence_parallel_group,
)


logger = logging.get_logger(__name__)


class SequenceParallelModelPlugin(BasePlugin):
    def __call__(self, model, model_args):
        return super().__call__(model, model_args)


class SequenceParallelLossPlugin(BasePlugin):
    def __call__(self, model, inputs, *args, **kwargs):
        return super().__call__(model, inputs, *args, **kwargs)


def new_flash_attn_forward(
    query_states,
    key_states,
    value_states,
    attention_mask,
    sequence_parallel_size=1,
    dropout=0,
    deterministic=False,
    is_causal=True,
    group=None,
    mode="ulysses",
    attn_fn=None,
    target_dtype=None,
    **kwargs,
):
    if mode == "ulysses":
        dist_attn = UlyssesAttention(sequence_process_group=group, attn_fn=attn_fn)
        attn_output = dist_attn(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length=query_states.shape[1] * sequence_parallel_size,
            deterministic=deterministic,
            dropout_p=dropout,
            causal=is_causal,
            position_ids=kwargs.get("position_ids", None),
            target_dtype=target_dtype,
        )
    else:
        raise NotImplementedError("Other sequence parallel modes are to be implemented.")

    return attn_output


@SequenceParallelModelPlugin("ulysses").register()
def apply_sequence_parallel(model, model_args):
    # Replace _flash_attention_forward with new_flash_attn_forward
    module = sys.modules[model.__module__]
    cp_size = model_args.get("cp_size", 1)

    set_ulysses_sequence_parallel_group(DistributedInterface().get_group(Dim.CP))

    try:
        num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
    except AttributeError:
        num_attention_heads, num_key_value_heads = (
            model.config.text_config.num_attention_heads,
            model.config.text_config.num_key_value_heads,
        )

    assert num_attention_heads % cp_size == 0, "num_attention_heads must be divisible by cp_size"
    assert num_key_value_heads % cp_size == 0 or cp_size % num_key_value_heads == 0, (
        "num_key_value_heads must be divisible by cp_size, or cp_size by num_key_value_heads"
    )

    origin_attn = transformers.modeling_flash_attention_utils._flash_attention_forward
    new_flash_attention_forward = partial(
        new_flash_attn_forward,
        group=get_ulysses_sequence_parallel_group(),
        mode="ulysses",
        attn_fn=origin_attn,
        sequence_parallel_size=cp_size,
    )

    for module_name, module in list(sys.modules.items()):
        try:
            if (
                hasattr(module, "__file__")
                and "transformers" in module.__file__
                and getattr(module._flash_attention_forward, "__name__", "") == "_flash_attention_forward"
            ):
                module._flash_attention_forward = new_flash_attention_forward
                logger.info_rank0(
                    f"Replaced _flash_attention_forward in module {module_name} with new_flash_attn_forward for sequence parallel."
                )
        except (AttributeError, TypeError):
            continue


def padding_and_split_data(data, device_mesh=None):
    if device_mesh is not None:
        cp_size = device_mesh["cp"].size()
        cp_rank = device_mesh["cp"].get_local_rank()
        cp_group = device_mesh["cp"].get_group()
        for k, v in data.items():
            if isinstance(v, torch.Tensor) and v.ndim > 1:
                data_len = torch.tensor(v.shape[-1], device=v.device, dtype=torch.int64)
                global_data_len = [torch.empty_like(data_len) for _ in range(cp_size)]
                dist.all_gather(global_data_len, data_len, group=cp_group)
                max_data_len = max(global_data_len)
                pad_size = max_data_len - v.shape[-1] + (cp_size - max_data_len % cp_size) % cp_size
                if k == "labels":
                    pad_value = -100
                elif k == "loss_weights":
                    pad_value = 0.0
                else:
                    pad_value = 0
                pad_data = F.pad(v, (0, pad_size), value=pad_value)
                data[k] = torch.chunk(pad_data, chunks=cp_size, dim=-1)[cp_rank].contiguous()
    return data
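A note on the padding arithmetic in `padding_and_split_data` above: each sequence is padded to the global max length plus a rounding term `(cp_size - max_len % cp_size) % cp_size`, which makes the padded length a multiple of `cp_size` and is zero when it already is. Checking the arithmetic with plain integers (standalone, no torch):

```python
def pad_size_for(local_len, global_lens, cp_size):
    # Pad to the longest sequence across ranks, rounded up to a multiple of cp_size.
    max_len = max(global_lens)
    return max_len - local_len + (cp_size - max_len % cp_size) % cp_size

# cp_size=4, ranks hold lengths 10 and 13 -> everyone pads to 16 (next multiple of 4).
print(pad_size_for(10, [10, 13], 4))  # 6  (10 + 6 == 16)
print(pad_size_for(13, [10, 13], 4))  # 3  (13 + 3 == 16)
# Already divisible: the rounding term vanishes.
print(pad_size_for(8, [8, 8], 4))     # 0
```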


@SequenceParallelLossPlugin("sequence_parallel_loss").register()
def sequence_parallel_loss(model, model_inputs):
    device_mesh = DistributedInterface().get_device_mesh(Dim.CP)

    model_inputs = {
        k: v.to(dist.get_rank(), non_blocking=True) for k, v in model_inputs.items() if isinstance(v, torch.Tensor)
    }

    model_inputs = padding_and_split_data(model_inputs, device_mesh)

    batch_size, _ = model_inputs["labels"].shape

    outputs: ModelOutput = model(**model_inputs)

    logits = outputs.logits.float()

    labels = model_inputs["labels"]

    cp_group = get_ulysses_sequence_parallel_group()
    cp_world_size = get_ulysses_sequence_parallel_world_size(cp_group)
    cp_rank = get_ulysses_sequence_parallel_rank(cp_group)

    # use all_gather to collect labels from all sequence parallel processes
    global_labels = [torch.empty_like(labels) for _ in range(cp_world_size)]
    dist.all_gather(global_labels, labels, group=cp_group)
    labels = torch.cat(global_labels, dim=1).contiguous()
    shift_labels = labels[..., 1:].view(-1).contiguous()
    shift_labels = F.pad(shift_labels, (0, 1), value=-100)
    shift_labels = torch.chunk(shift_labels, chunks=cp_world_size, dim=-1)[cp_rank].contiguous()

    # use all_gather to collect loss_weights from all sequence parallel processes
    loss_weights = model_inputs["loss_weights"]
    global_loss_weights = [torch.empty_like(loss_weights) for _ in range(cp_world_size)]
    dist.all_gather(global_loss_weights, loss_weights, group=cp_group)
    shift_loss_weights = torch.cat(global_loss_weights, dim=1).contiguous()
    shift_loss_weights = shift_loss_weights[..., 1:].contiguous()

    shift_logits = logits.view(shift_labels.size(0), -1).contiguous()

    # use all_gather to collect log_probs from all sequence parallel processes
    log_probs = -F.cross_entropy(shift_logits, shift_labels, reduction="none").view(batch_size, -1)
    global_log_probs = dist.nn.all_gather(log_probs, group=cp_group)
    global_log_probs = torch.cat(global_log_probs, dim=1).contiguous()
    log_probs = global_log_probs[..., :-1].contiguous()

    loss = (-log_probs * shift_loss_weights).sum() / (shift_loss_weights.sum() + 1e-6)

    return loss
@@ -0,0 +1,163 @@ (new file)

```python
# Copyright 2025 Bytedance Ltd. and/or its affiliates. and the LlamaFactory team.
#
# This code is inspired by the Bytedance's verl library.
# https://github.com/verl-project/verl/blob/77476af84cc074edf5a6437f8d5ea418d7a54916/verl/utils/ulysses.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup

from .seq_comm import SeqAllToAll4D


_ULYSSES_SEQUENCE_PARALLEL_GROUP = None


def set_ulysses_sequence_parallel_group(group: dist.ProcessGroup):
    """Set ulysses sequence parallel process group."""
    global _ULYSSES_SEQUENCE_PARALLEL_GROUP
    _ULYSSES_SEQUENCE_PARALLEL_GROUP = group


def get_ulysses_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """Get ulysses sequence parallel process group."""
    global _ULYSSES_SEQUENCE_PARALLEL_GROUP
    return _ULYSSES_SEQUENCE_PARALLEL_GROUP


def get_ulysses_sequence_parallel_world_size(group: ProcessGroup = None) -> int:
    """Get ulysses sequence parallel world size."""
    group = get_ulysses_sequence_parallel_group() if group is None else group
    return dist.get_world_size(group) if group else 1


def get_ulysses_sequence_parallel_rank(group: ProcessGroup = None) -> int:
    """Get ulysses sequence parallel rank."""
    group = get_ulysses_sequence_parallel_group() if group is None else group
    return dist.get_rank(group) if group else 0


class UlyssesAttention(torch.nn.Module):
    """Ulysses sequence-parallel attention.

    Arguments:
        sequence_process_group (ProcessGroup): sequence parallel process group
        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
        attn_fn (callable): local attention function applied to q, k, v
    """

    def __init__(
        self,
        sequence_process_group: dist.ProcessGroup = None,
        scatter_idx: int = 2,
        gather_idx: int = 1,
        attn_fn: Optional[callable] = None,
    ) -> None:
        super().__init__()
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx
        self.attn_fn = attn_fn

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: torch.Tensor,
        query_length: int,
        dropout_p=0.0,
        softmax_scale=None,
        position_ids: Optional[torch.Tensor] = None,
        causal=True,
        deterministic=False,
        target_dtype=None,
        *args: Any,
    ) -> Tensor:
        """Forward.

        Arguments:
            query (Tensor): query input to the layer
            key (Tensor): key input to the layer
            value (Tensor): value input to the layer
            attention_mask (Tensor): attention mask for the layer
            query_length (int): the length of the query sequence
            dropout_p (float, optional): dropout probability. Defaults to 0.0.
            softmax_scale (float, optional): scale factor for softmax. Defaults to None.
            position_ids (torch.Tensor, optional): position ids for the attention. Defaults to None.
            causal (bool, optional): whether to apply causal mask. Defaults to True.
            deterministic (bool, optional): whether to apply dropout in a deterministic way. Defaults to False.
            target_dtype (torch.dtype, optional): target dtype for attention output. Defaults to None.
            args: other args

        Returns:
            output (Tensor): context output
        """
        # TODO: merge the three all-to-all calls into one
        # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q, k and v) together!
        # in shape: e.g., [s/p:h:]
        # (bs, seq_len/N, head_cnt, head_size) -> (bs, seq_len, head_cnt/N, head_size)

        # scatter 2, gather 1
        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** -0.5

        if attention_mask is None:
            if position_ids is not None:
                attention_mask = torch.ones_like(position_ids).to(torch.int64)
            else:
                attention_mask = torch.ones(q.shape[0], q.shape[1], dtype=torch.int64, device=q.device)
        else:
            attention_mask = attention_mask.to(torch.int64)

        global_attention_mask = [
            torch.empty_like(attention_mask) for _ in range(get_ulysses_sequence_parallel_world_size(self.spg))
        ]
        dist.all_gather(global_attention_mask, attention_mask, group=self.spg)
        attention_mask = torch.cat(global_attention_mask, dim=1)

        context_layer = self.attn_fn(
            q,
            k,
            v,
            attention_mask,
            query_length=query_length,
            is_causal=causal,
            dropout=dropout_p,
            position_ids=position_ids,
            softmax_scale=softmax_scale,
            deterministic=deterministic,
            target_dtype=target_dtype,
        )

        if isinstance(context_layer, tuple):
            context_layer = context_layer[0]

        # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
        # scatter 1, gather 2
        output = SeqAllToAll4D.apply(self.spg, context_layer, self.gather_idx, self.scatter_idx)

        # out e.g., [s/p::h]
        return output
```
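The two `SeqAllToAll4D.apply` calls above reshard tensors between a sequence-sharded layout (each rank holds `seq_len/N` positions of all heads) and a head-sharded layout (each rank holds all positions of `head_cnt/N` heads), so the local attention function sees the full sequence. A shape-only sketch of the round trip, assuming a hypothetical `seq_all_to_all` helper that tracks sizes rather than moving data:

```python
def seq_all_to_all(shape, scatter_idx, gather_idx, world_size):
    """Shape effect of the 4D all-to-all: divide the scatter dim by N,
    multiply the gather dim by N (total element count is unchanged)."""
    out = list(shape)
    assert out[scatter_idx] % world_size == 0
    out[scatter_idx] //= world_size
    out[gather_idx] *= world_size
    return tuple(out)


N = 4
local_qkv = (2, 2048 // N, 16, 128)  # (bs, seq_len/N, head_cnt, head_size)

# forward pass: scatter heads (dim 2), gather sequence (dim 1)
full_seq = seq_all_to_all(local_qkv, scatter_idx=2, gather_idx=1, world_size=N)
print(full_seq)  # (2, 2048, 4, 128): full sequence, 1/N of the heads

# after attention: scatter sequence (dim 1), gather heads (dim 2)
back = seq_all_to_all(full_seq, scatter_idx=1, gather_idx=2, world_size=N)
print(back == local_qkv)  # True: output is sequence-sharded again
```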
```diff
@@ -150,9 +150,6 @@ def load_adapter(model: HFModel, adapter_name_or_path: Union[list[str], str], is
 
 @PeftPlugin("lora").register()
 def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool = False) -> HFModel:
-    if model.device.type == "meta":
-        raise ValueError("Currently lora stage does not support loading model by meta.")
-
     adapter_name_or_path = config.get("adapter_name_or_path")
 
     if adapter_name_or_path:
```
```diff
@@ -17,6 +17,7 @@ import gc
 import os
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from peft.tuners.lora import LoraLayer
 from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict, set_model_state_dict
```
```diff
@@ -84,10 +85,7 @@ class FSDP2Engine:
         )
 
         if self.device_mesh is not None:
-            try:
-                self.fsdp_mesh = self.device_mesh["dp"]
-            except Exception:
-                self.fsdp_mesh = self.device_mesh
+            self.fsdp_mesh = self.device_mesh
 
             logger.info(f"Using Device Mesh: {self.fsdp_mesh}")
         else:
```
```diff
@@ -244,23 +242,57 @@ class FSDP2Engine:
         logger.info(f"Restored {len(saved_buffers)} non-persistent buffers")
 
     def shard_model(self, model: HFModel) -> HFModel:
-        if model.device.type == "meta":
+        init_mode = getattr(model, "_init_mode", "init_on_default")
+
+        if init_mode == "init_on_rank0":
+            if getattr(model.config, "tie_word_embeddings", False):
+                model.tie_weights()
+
+            if self.rank == 0:
+                logger.info("init_on_rank0 detected: sharding then scattering Rank 0 CPU weights.")
+                full_sd = {k: v.clone() for k, v in model.state_dict().items()}
+            else:
+                full_sd = {}
+
+            # Reuse existing helper to save persistent=False buffers (e.g. inv_freq) before shard
+            saved_buffers = self._save_non_persistent_buffers(model) if self.rank == 0 else {}
+
+            model = self.prepare_model(model)
+
+            device = get_current_accelerator()
+            model.to_empty(device=device)
+
+            # Scatter params from Rank 0 into all DTensor shards
+            # Broadcast the full state dict from the global rank-0 process to all ranks in this group.
+            options = StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True)
+            set_model_state_dict(model, full_sd, options=options)
+
+            # Broadcast and restore non-persistent buffers
+            buffers_to_sync = [saved_buffers]
+            dist.broadcast_object_list(buffers_to_sync, src=0, group=self.fsdp_mesh.get_group())
+            self._restore_non_persistent_buffers(model, buffers_to_sync[0])
+
+            if self.rank == 0:
+                logger.info("init_on_rank0 sync complete.")
+
+        elif init_mode == "init_on_meta":
             non_persistent_buffers = self._save_non_persistent_buffers(model)
 
-            if getattr(model.config, "tie_word_embeddings", None):
+            if getattr(model.config, "tie_word_embeddings", False):
                 model.tie_weights()
 
             model = self.prepare_model(model)
             model = self.materialize_and_load(model, hf_model_path=model.config.name_or_path, dcp_path=self.dcp_path)
 
             # fix tied broken for no-fsdp-wrap case
-            if getattr(model.config, "tie_word_embeddings", None):
+            if getattr(model.config, "tie_word_embeddings", False):
                 model.tie_weights()
 
             self._restore_non_persistent_buffers(model, non_persistent_buffers)
 
         else:
             model = self.prepare_model(model)
 
         return model
 
     def _load_from_dcp(self, model: HFModel, dcp_path: str):
```
src/llamafactory/v1/utils/callbacks/__init__.py (new file, 24 lines)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .logging_callback import LoggingCallback
from .trainer_callback import CallbackHandler, TrainerCallback, TrainerState


__all__ = [
    "CallbackHandler",
    "LoggingCallback",
    "TrainerCallback",
    "TrainerState",
]
```
src/llamafactory/v1/utils/callbacks/logging_callback.py (new file, 64 lines)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import os
from typing import TYPE_CHECKING, Any

from .. import logging
from .trainer_callback import TrainerCallback, TrainerState


if TYPE_CHECKING:
    from ...config import TrainingArguments


logger = logging.get_logger(__name__)


class LoggingCallback(TrainerCallback):
    """Logs training metrics to stdout on rank-0 and appends to ``state.log_history``.

    On each logging step the entry is also persisted as a JSON line in
    ``<output_dir>/trainer_log.jsonl`` so that training history survives crashes.
    """

    def on_log(
        self,
        args: TrainingArguments,
        state: TrainerState,
        logs: dict[str, Any],
        **kwargs: Any,
    ) -> None:
        # Persist in history regardless of rank
        state.log_history.append(dict(logs))

        # Everything below is rank-0 only
        from ...accelerator.interface import DistributedInterface  # lazy import

        if DistributedInterface().get_rank() != 0:
            return

        # Human-readable output to stdout
        display_logs = {**logs, "total_steps": state.num_training_steps}
        parts = ", ".join(f"{k}: {v:.4f}" if isinstance(v, float) else f"{k}: {v}" for k, v in display_logs.items())
        logger.info_rank0(parts)

        # Append to JSONL log file in output_dir
        os.makedirs(args.output_dir, exist_ok=True)
        log_file = os.path.join(args.output_dir, "trainer_log.jsonl")
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(display_logs, ensure_ascii=False) + "\n")
```
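The `on_log` hook above appends one JSON object per logging step to `trainer_log.jsonl`, so the file can be replayed line by line after a crash. A self-contained sketch of that append-and-replay pattern using only the standard library (`append_log` is an illustrative name, not part of the source):

```python
import json
import os
import tempfile


def append_log(output_dir, entry):
    """Append one log entry as a JSON line, creating the dir if needed."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, "trainer_log.jsonl")
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
    return path


with tempfile.TemporaryDirectory() as d:
    for step in (1, 2):
        path = append_log(d, {"step": step, "loss": 0.5 / step})
    # Replay: each line is an independent JSON object
    with open(path, encoding="utf-8") as f:
        history = [json.loads(line) for line in f]

print(history[1]["step"])  # 2
```

Appending one object per line (rather than rewriting a single JSON array) keeps every earlier entry intact even if the process dies mid-write.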
src/llamafactory/v1/utils/callbacks/trainer_callback.py (new file, 147 lines)

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any


if TYPE_CHECKING:
    from ...config import TrainingArguments


@dataclass
class TrainerState:
    """A read-only snapshot of training progress passed to every callback hook.

    Attributes:
        epoch: Current epoch (0-indexed).
        global_step: Number of optimizer steps completed so far.
        num_training_steps: Total number of optimizer steps planned.
        loss: Scalar loss value of the most recent step.
        grad_norm: Gradient-norm value of the most recent step.
        learning_rate: Current learning rate seen by the optimizer.
        log_history: List of per-step log dicts emitted by ``LoggingCallback``.
    """

    epoch: int = 0
    global_step: int = 0
    num_training_steps: int = 0
    loss: float = 0.0
    grad_norm: float = 0.0
    learning_rate: float = 0.0
    log_history: list[dict[str, Any]] = field(default_factory=list)


class TrainerCallback:
    """Abstract base class for training callbacks.

    Subclass and override whichever hooks you need. All hooks receive:

    - ``args`` – the :class:`~llamafactory.v1.config.TrainingArguments`.
    - ``state`` – a :class:`TrainerState` snapshot (read-only).
    - ``**kwargs`` – extra keyword arguments (model, optimizer, …).

    Callbacks are *observers*: they should NOT mutate training flow.

    Hook call order::

        on_train_begin
        for each epoch:
            on_epoch_begin
            for each step:
                on_step_begin
                (forward / backward / optimizer.step)
                on_step_end
                [on_log]  ← if this step is a logging step
            on_epoch_end
        on_train_end
    """

    def on_train_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Called once before the first training step."""

    def on_train_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Called once after the last training step."""

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Called at the beginning of each epoch."""

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Called at the end of each epoch."""

    def on_step_begin(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Called before the forward/backward pass of each optimizer step."""

    def on_step_end(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Called after the optimizer step."""

    def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any], **kwargs: Any) -> None:
        """Called when the trainer emits a log entry."""

    def on_save(self, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        """Called after the model checkpoint has been written to disk."""


class CallbackHandler:
    """Owns a list of :class:`TrainerCallback` instances and fans out hook calls.

    Usage::

        handler = CallbackHandler([LoggingCallback(), MyWandbCallback()], trainer=trainer)
        handler.on_train_begin(args, state)
    """

    def __init__(self, callbacks: list[TrainerCallback] | None = None, trainer: Any = None) -> None:
        self.callbacks: list[TrainerCallback] = list(callbacks or [])
        self.trainer = trainer

    def add_callback(self, callback: TrainerCallback) -> None:
        """Append a callback to the handler."""
        self.callbacks.append(callback)

    def _call(self, event: str, args: TrainingArguments, state: TrainerState, **kwargs: Any) -> None:
        if self.trainer is not None:
            kwargs.setdefault("model", getattr(self.trainer, "model", None))
            kwargs.setdefault("optimizer", getattr(self.trainer, "optimizer", None))
            kwargs.setdefault("lr_scheduler", getattr(self.trainer, "lr_scheduler", None))
            kwargs.setdefault("train_dataloader", getattr(self.trainer, "train_batch_generator", None))

        for cb in self.callbacks:
            getattr(cb, event)(args, state, **kwargs)

    def on_train_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_train_begin", args, state)

    def on_train_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_train_end", args, state)

    def on_epoch_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_epoch_begin", args, state)

    def on_epoch_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_epoch_end", args, state)

    def on_step_begin(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_step_begin", args, state)

    def on_step_end(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_step_end", args, state)

    def on_log(self, args: TrainingArguments, state: TrainerState, logs: dict[str, Any]) -> None:
        self._call("on_log", args, state, logs=logs)

    def on_save(self, args: TrainingArguments, state: TrainerState) -> None:
        self._call("on_save", args, state)
```
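The handler above is a plain fan-out: each `on_*` dispatcher forwards one event to every registered callback via `getattr`. A trimmed, self-contained sketch of that pattern (the `State`, `CountingCallback`, and `Handler` names below are illustrative stand-ins, not the real classes):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class State:  # trimmed stand-in for TrainerState
    global_step: int = 0
    log_history: list[dict[str, Any]] = field(default_factory=list)


class CountingCallback:  # observer only: records, never mutates training flow
    def __init__(self) -> None:
        self.steps = 0

    def on_step_end(self, args, state, **kwargs):
        self.steps += 1

    def on_log(self, args, state, logs, **kwargs):
        state.log_history.append(dict(logs))


class Handler:  # fans one event out to every callback, like _call above
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def fire(self, event, args, state, **kwargs):
        for cb in self.callbacks:
            getattr(cb, event)(args, state, **kwargs)


state = State()
cb = CountingCallback()
handler = Handler([cb])
for step in range(3):  # toy training loop
    state.global_step = step + 1
    handler.fire("on_step_end", None, state)
    handler.fire("on_log", None, state, logs={"loss": 0.5})

print(cb.steps, len(state.log_history))  # 3 3
```

Because dispatch goes through `getattr(cb, event)`, a callback only pays for the hooks it actually overrides; the base-class no-ops absorb the rest.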
```diff
@@ -57,7 +57,7 @@ TEXT_MESSAGES = [
 ]
 
 VIDEO_MESSAGES = [
-    {"role": "user", "content": "<video>What is in this viode?"},
+    {"role": "user", "content": "<video>What is in this video?"},
     {"role": "assistant", "content": "A cat."},
 ]
 
```
```diff
@@ -210,6 +210,34 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)
 
 
+@pytest.mark.runs_on(["cpu", "mps"])
+@pytest.mark.skipif(not is_transformers_version_greater_than("5.6.0"), reason="Requires transformers>=5.6.0")
+def test_gemma4_plugin():
+    tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-4-31B-it")
+    processor = tokenizer_module["processor"]
+    gemma4_plugin = get_mm_plugin(name="gemma4", image_token="<|image|>", video_token="<|video|>")
+    check_inputs = {"plugin": gemma4_plugin, **tokenizer_module}
+    # validate
+    mm_inputs = gemma4_plugin._get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, processor)
+    num_image_soft_tokens = 256  # when we use default max_soft_tokens=280
+    image_token = getattr(processor, "image_token")
+    boi_token = getattr(processor, "boi_token")
+    eoi_token = getattr(processor, "eoi_token")
+
+    expected_mm_type_ids = [
+        [int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS
+    ]
+    check_inputs["expected_mm_messages"] = [
+        {"role": "user", "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?"},
+        {"role": "assistant", "content": "A cat."},
+    ]
+    for key in ("num_soft_tokens_per_image",):
+        mm_inputs.pop(key, None)
+
+    mm_inputs["mm_token_type_ids"] = expected_mm_type_ids
+    check_inputs["expected_mm_inputs"] = mm_inputs
+    check_inputs["expected_no_mm_inputs"] = {"mm_token_type_ids": expected_mm_type_ids}
+    _check_plugin(**check_inputs)
+
+
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_internvl_plugin():
```
|||||||
62
tests_v1/plugins/model_plugins/test_ulysses_cp.py
Normal file
62
tests_v1/plugins/model_plugins/test_ulysses_cp.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
# Copyright 2025 the LlamaFactory team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||||
|
from llamafactory.v1.config.model_args import ModelArguments
|
||||||
|
from llamafactory.v1.core.model_engine import ModelEngine
|
||||||
|
from llamafactory.v1.plugins.model_plugins.parallelization.sequence_parallel import (
|
||||||
|
SequenceParallelModelPlugin,
|
||||||
|
sequence_parallel_loss,
|
||||||
|
)
|
||||||
|
from llamafactory.v1.utils.env import find_available_port
|
||||||
|
from llamafactory.v1.utils.pytest import dist_env
|
||||||
|
|
||||||
|
|
||||||
|
def _test_sequence_parallel_loss(local_rank: int, world_size: int, master_port: int, cp_size: int, dp_size: int):
|
||||||
|
with dist_env(local_rank, world_size, master_port):
|
||||||
|
model_args = ModelArguments(model="llamafactory/tiny-random-qwen3")
|
||||||
|
|
||||||
|
# Initialize distributed interface with config
|
||||||
|
dist_config = {"cp_mode": "ulysses", "cp_size": cp_size, "dp_size": dp_size}
|
||||||
|
DistributedInterface(dist_config)
|
||||||
|
|
||||||
|
# Now create model engine
|
||||||
|
model_engine = ModelEngine(model_args=model_args)
|
||||||
|
|
||||||
|
# Apply sequence parallel plugin
|
||||||
|
SequenceParallelModelPlugin(dist_config.get("cp_mode", "ulysses"))(model_engine.model, dist_config)
|
||||||
|
|
||||||
|
model_inputs = {
|
||||||
|
"input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
|
||||||
|
"labels": torch.tensor([[1, 2, 3, 4, 5]]),
|
||||||
|
"attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
|
||||||
|
"position_ids": torch.tensor([[1, 2, 3, 4, 5]]),
|
||||||
|
"loss_weights": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]),
|
||||||
|
}
|
||||||
|
|
||||||
|
loss = sequence_parallel_loss(model_engine.model, model_inputs)
|
||||||
|
assert loss is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cuda", "npu"])
|
||||||
|
@pytest.mark.require_distributed(2)
|
||||||
|
@pytest.mark.parametrize("cp_size, dp_size", [(2, 1)])
|
||||||
|
def test_sequence_parallel_loss(cp_size, dp_size):
|
||||||
|
master_port = find_available_port()
|
||||||
|
world_size = cp_size * dp_size
|
||||||
|
mp.spawn(_test_sequence_parallel_loss, args=(world_size, master_port, cp_size, dp_size), nprocs=world_size)
|
||||||
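One detail worth noting in the test above: `mp.spawn(fn, args=..., nprocs=n)` invokes `fn(local_rank, *args)` once per process, which is why `_test_sequence_parallel_loss` takes `local_rank` as its first parameter even though it never appears in `args`. A process-free sketch of that calling convention (`fake_spawn` is an illustrative stand-in, not the real API):

```python
def fake_spawn(fn, args, nprocs):
    """Imitate torch.multiprocessing.spawn's calling convention:
    the local rank is prepended to args for each worker."""
    for local_rank in range(nprocs):  # real spawn runs these concurrently
        fn(local_rank, *args)


seen = []


def worker(local_rank, world_size, port):
    seen.append((local_rank, world_size, port))


fake_spawn(worker, args=(2, 29500), nprocs=2)
print(seen)  # [(0, 2, 29500), (1, 2, 29500)]
```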