mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-06 12:26:01 +08:00
[model] gemma4 (#10346)
This commit is contained in:
105
.ai/CLAUDE.md
Normal file
105
.ai/CLAUDE.md
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Code style (auto-fix)
|
||||||
|
make style
|
||||||
|
|
||||||
|
# Code quality check (no modifications)
|
||||||
|
make quality
|
||||||
|
|
||||||
|
# Run all tests
|
||||||
|
make test
|
||||||
|
|
||||||
|
# Run a single test file
|
||||||
|
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/path/to/test_file.py
|
||||||
|
|
||||||
|
# Run tests matching a pattern
|
||||||
|
WANDB_DISABLED=true pytest -vv --import-mode=importlib tests/ -k "test_name"
|
||||||
|
|
||||||
|
# License header check
|
||||||
|
make license
|
||||||
|
|
||||||
|
# Build package
|
||||||
|
make build
|
||||||
|
```
|
||||||
|
|
||||||
|
The project uses `uv` as the preferred package manager. Commands automatically use `uv run` / `uvx` if `uv` is available.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
LlamaFactory has two parallel architectures controlled by the `USE_V1` environment variable:
|
||||||
|
|
||||||
|
- **v0 (default):** `api, webui > chat, eval, train > data, model > hparams > extras`
|
||||||
|
- **v1 (experimental, `USE_V1=1`):** `trainers > core > accelerator, plugins, config > utils`
|
||||||
|
|
||||||
|
Most active development happens in v0. The v1 architecture lives in `src/llamafactory/v1/`.
|
||||||
|
|
||||||
|
### Entry Points
|
||||||
|
|
||||||
|
CLI entry point is `llamafactory-cli` / `lmf` → `src/llamafactory/cli.py:main()`, which dispatches to `launcher.py` based on `USE_V1`.
|
||||||
|
|
||||||
|
Available subcommands: `train`, `chat`, `api`, `export`, `webchat`, `webui`, `env`, `version`, `help`.
|
||||||
|
|
||||||
|
### Training Flow (v0)
|
||||||
|
|
||||||
|
```
|
||||||
|
run_exp() [tuner.py]
|
||||||
|
→ read_args() → parse YAML/JSON config
|
||||||
|
→ get_train_args() → produces typed argument dataclasses
|
||||||
|
→ routes to: run_sft / run_dpo / run_ppo / run_rm / run_pt / run_kto
|
||||||
|
→ optional: export_model()
|
||||||
|
```
|
||||||
|
|
||||||
|
Training is invoked with a YAML config: `llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml`
|
||||||
|
|
||||||
|
### Configuration System
|
||||||
|
|
||||||
|
All training parameters are YAML/JSON config files. Argument parsing in `src/llamafactory/hparams/parser.py` produces four typed dataclasses:
|
||||||
|
- `ModelArguments` — model/tokenizer selection, quantization
|
||||||
|
- `DataArguments` — datasets, templates, preprocessing
|
||||||
|
- `FinetuningArguments` — LoRA rank/target, training method (sft/dpo/ppo/rm/pt/kto)
|
||||||
|
- `TrainingArguments` — extends HuggingFace's `TrainingArguments`
|
||||||
|
|
||||||
|
### Key Modules
|
||||||
|
|
||||||
|
| Module | Purpose |
|
||||||
|
|--------|---------|
|
||||||
|
| `src/llamafactory/model/loader.py` | Loads model + tokenizer; applies quantization, LoRA, patches |
|
||||||
|
| `src/llamafactory/model/patcher.py` | Model-specific compatibility patches |
|
||||||
|
| `src/llamafactory/data/template.py` | Prompt templates; `TEMPLATES` dict maps model family → format |
|
||||||
|
| `src/llamafactory/data/mm_plugin.py` | Multi-modal (image/video/audio) data handling |
|
||||||
|
| `src/llamafactory/data/processor/` | Per-stage data processors (supervised, pairwise, pretrain, etc.) |
|
||||||
|
| `src/llamafactory/train/sft/` | SFT trainer; other stages follow same structure |
|
||||||
|
| `src/llamafactory/chat/` | Inference engines: `hf_engine`, `vllm_engine`, `sglang_engine`, `kt_engine` |
|
||||||
|
| `src/llamafactory/extras/constants.py` | Enums and constants used across the project |
|
||||||
|
|
||||||
|
### Adding Support for a New Model
|
||||||
|
|
||||||
|
1. Add a prompt template to `src/llamafactory/data/template.py` in the `TEMPLATES` dict
|
||||||
|
2. Add any necessary model patches in `src/llamafactory/model/patcher.py`
|
||||||
|
3. Add multi-modal support in `src/llamafactory/data/mm_plugin.py` if needed
|
||||||
|
|
||||||
|
### Distributed Training
|
||||||
|
|
||||||
|
Multi-GPU automatically uses `torchrun`. Additional backends:
|
||||||
|
- **Ray:** Optional Ray cluster support
|
||||||
|
- **HyperParallel FSDP2:** `src/llamafactory/train/hyper_parallel/`
|
||||||
|
- **Megatron-core:** `src/llamafactory/train/mca/`
|
||||||
|
|
||||||
|
### Testing
|
||||||
|
|
||||||
|
- `tests/` — v0 tests; `tests_v1/` — v1 tests
|
||||||
|
- Most training tests require GPU hardware
|
||||||
|
- pytest markers: `@pytest.mark.slow`, `@pytest.mark.runs_on(['cuda'])`
|
||||||
|
- Always set `WANDB_DISABLED=true` when running tests
|
||||||
|
|
||||||
|
### Code Style
|
||||||
|
|
||||||
|
- Ruff for linting and formatting (line length 119, Google-style docstrings)
|
||||||
|
- Python 3.11+ syntax
|
||||||
|
- Double quotes for strings
|
||||||
|
- All new files must include Apache 2.0 license header (checked by `make license`)
|
||||||
@@ -607,6 +607,194 @@ class Gemma3nPlugin(Gemma3Plugin):
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Gemma4Plugin(BasePlugin):
|
||||||
|
r"""Plugin for the Gemma4 multimodal model."""
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||||
|
r"""Regularize videos, also tracking per-video FPS and frame indices for timestamp generation."""
|
||||||
|
results, fps_per_video, durations, frames_indices = [], [], [], []
|
||||||
|
for video in videos:
|
||||||
|
frames: list[ImageObject] = []
|
||||||
|
if _check_video_is_nested_images(video):
|
||||||
|
frames = video
|
||||||
|
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||||
|
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||||
|
frames_indices.append(list(range(len(frames))))
|
||||||
|
else:
|
||||||
|
container = av.open(video, "r")
|
||||||
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
|
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||||
|
original_fps = float(video_stream.average_rate)
|
||||||
|
# for correctly calculate timestamps
|
||||||
|
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])
|
||||||
|
container.seek(0)
|
||||||
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
|
if frame_idx in sample_indices:
|
||||||
|
frames.append(frame.to_image())
|
||||||
|
|
||||||
|
if video_stream.duration is None:
|
||||||
|
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||||
|
else:
|
||||||
|
durations.append(float(video_stream.duration * video_stream.time_base))
|
||||||
|
|
||||||
|
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||||
|
results.append(frames)
|
||||||
|
|
||||||
|
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
processor: "MMProcessor",
|
||||||
|
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||||
|
image_processor = getattr(processor, "image_processor", None)
|
||||||
|
video_processor = getattr(processor, "video_processor", None)
|
||||||
|
feature_extractor = getattr(processor, "feature_extractor", None)
|
||||||
|
mm_inputs = {}
|
||||||
|
|
||||||
|
if len(images) != 0:
|
||||||
|
regularized = self._regularize_images(
|
||||||
|
images,
|
||||||
|
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||||
|
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||||
|
)["images"]
|
||||||
|
mm_inputs.update(image_processor(regularized, return_tensors="pt"))
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
video_data = self._regularize_videos(
|
||||||
|
videos,
|
||||||
|
image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
|
||||||
|
image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
|
||||||
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
|
)
|
||||||
|
video_metadata = [
|
||||||
|
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
||||||
|
for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
|
||||||
|
]
|
||||||
|
mm_inputs.update(
|
||||||
|
video_processor(
|
||||||
|
videos=video_data["videos"],
|
||||||
|
video_metadata=video_metadata,
|
||||||
|
return_tensors="pt",
|
||||||
|
return_metadata=True,
|
||||||
|
do_sample_frames=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(audios) != 0: # only for gemma4n
|
||||||
|
audios = self._regularize_audios(
|
||||||
|
audios,
|
||||||
|
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||||
|
)["audios"]
|
||||||
|
|
||||||
|
mm_inputs.update(
|
||||||
|
feature_extractor(
|
||||||
|
audios,
|
||||||
|
padding="max_length",
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
processor: Optional["MMProcessor"],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
self._validate_messages(messages, images, videos, audios)
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
|
||||||
|
boi_token: str = getattr(processor, "boi_token")
|
||||||
|
eoi_token: str = getattr(processor, "eoi_token")
|
||||||
|
boa_token: str = getattr(processor, "boa_token")
|
||||||
|
eoa_token: str = getattr(processor, "eoa_token")
|
||||||
|
image_token: str = getattr(processor, "image_token")
|
||||||
|
video_token: str = getattr(processor, "video_token")
|
||||||
|
audio_token: str = getattr(processor, "audio_token")
|
||||||
|
|
||||||
|
if self.expand_mm_tokens:
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
num_image_soft_tokens: list[int] = list(
|
||||||
|
mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images))
|
||||||
|
)
|
||||||
|
num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos)))
|
||||||
|
video_metadata = mm_inputs.get("video_metadata", [])
|
||||||
|
else:
|
||||||
|
num_image_soft_tokens = [1] * len(images)
|
||||||
|
num_video_soft_tokens = [1] * len(videos)
|
||||||
|
video_metadata = [None] * len(videos)
|
||||||
|
|
||||||
|
audio_iter = iter(audios)
|
||||||
|
image_iter = iter(num_image_soft_tokens)
|
||||||
|
video_iter = iter(zip(num_video_soft_tokens, video_metadata))
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
n = next(image_iter)
|
||||||
|
content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1)
|
||||||
|
|
||||||
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
num_soft_tokens_per_frame, metadata = next(video_iter)
|
||||||
|
if self.expand_mm_tokens:
|
||||||
|
timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
|
||||||
|
frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
|
||||||
|
video_str = " ".join(frame_strs)
|
||||||
|
else:
|
||||||
|
video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
|
||||||
|
content = content.replace(VIDEO_PLACEHOLDER, video_str, 1)
|
||||||
|
|
||||||
|
while AUDIO_PLACEHOLDER in content:
|
||||||
|
current_audio = next(audio_iter)
|
||||||
|
if self.expand_mm_tokens:
|
||||||
|
num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
|
||||||
|
audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
|
||||||
|
else:
|
||||||
|
audio_str = f"{boa_token}{audio_token}{eoa_token}"
|
||||||
|
|
||||||
|
content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)
|
||||||
|
|
||||||
|
message["content"] = content
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
imglens: list[int],
|
||||||
|
vidlens: list[int],
|
||||||
|
audlens: list[int],
|
||||||
|
batch_ids: list[list[int]],
|
||||||
|
processor: Optional["MMProcessor"],
|
||||||
|
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||||
|
# Pop metadata keys that must not be passed to the model.
|
||||||
|
for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata",
|
||||||
|
"_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"):
|
||||||
|
mm_inputs.pop(key, None)
|
||||||
|
|
||||||
|
mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InternVLPlugin(BasePlugin):
|
class InternVLPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
@@ -1505,7 +1693,7 @@ class Qwen2VLPlugin(BasePlugin):
|
|||||||
else:
|
else:
|
||||||
container = av.open(video, "r")
|
container = av.open(video, "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||||
original_fps = float(video_stream.average_rate)
|
original_fps = float(video_stream.average_rate)
|
||||||
# for qwen3vl video timestamp calculation
|
# for qwen3vl video timestamp calculation
|
||||||
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False
|
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False
|
||||||
@@ -1642,7 +1830,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
)
|
)
|
||||||
video_metadata = [
|
video_metadata = [
|
||||||
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
||||||
for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
|
for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
|
||||||
]
|
]
|
||||||
mm_inputs.update(
|
mm_inputs.update(
|
||||||
@@ -1683,7 +1871,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
|||||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||||
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
||||||
video_metadata = mm_inputs.get("video_metadata", {})
|
video_metadata = mm_inputs.get("video_metadata", [])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
image_grid_thw = [None] * len(images)
|
image_grid_thw = [None] * len(images)
|
||||||
@@ -2206,8 +2394,9 @@ PLUGINS = {
|
|||||||
"base": BasePlugin,
|
"base": BasePlugin,
|
||||||
"ernie_vl": ErnieVLPlugin,
|
"ernie_vl": ErnieVLPlugin,
|
||||||
"gemma3": Gemma3Plugin,
|
"gemma3": Gemma3Plugin,
|
||||||
"glm4v": GLM4VPlugin,
|
|
||||||
"gemma3n": Gemma3nPlugin,
|
"gemma3n": Gemma3nPlugin,
|
||||||
|
"gemma4": Gemma4Plugin,
|
||||||
|
"glm4v": GLM4VPlugin,
|
||||||
"intern_vl": InternVLPlugin,
|
"intern_vl": InternVLPlugin,
|
||||||
"kimi_vl": KimiVLPlugin,
|
"kimi_vl": KimiVLPlugin,
|
||||||
"llama4": Llama4Plugin,
|
"llama4": Llama4Plugin,
|
||||||
|
|||||||
@@ -997,6 +997,55 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="gemma4",
|
||||||
|
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]), # default thought singal contained
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
|
||||||
|
), # seem not consistent with the chattemplate
|
||||||
|
format_tools=ToolFormatter(tool_format="gemma4"),
|
||||||
|
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
stop_words=["<turn|>"],
|
||||||
|
default_system="You are a helpful assistant.", # important for thinking
|
||||||
|
thought_words=("<|channel>thought\n", "<channel|>"),
|
||||||
|
replace_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(
|
||||||
|
"gemma4",
|
||||||
|
image_token="<|image|>",
|
||||||
|
video_token="<|video|>",
|
||||||
|
),
|
||||||
|
template_class=ReasoningTemplate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="gemma4n",
|
||||||
|
format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]), # default thought singal contained
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="gemma4"),
|
||||||
|
format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
stop_words=["<turn|>"],
|
||||||
|
default_system="You are a helpful assistant.", # important for thinking
|
||||||
|
thought_words=("<|channel>thought\n", "<channel|>"),
|
||||||
|
replace_eos=True,
|
||||||
|
mm_plugin=get_mm_plugin(
|
||||||
|
"gemma4",
|
||||||
|
image_token="<|image|>",
|
||||||
|
video_token="<|video|>",
|
||||||
|
audio_token="<|audio|>",
|
||||||
|
),
|
||||||
|
template_class=ReasoningTemplate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_template(
|
register_template(
|
||||||
name="glm4",
|
name="glm4",
|
||||||
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
||||||
|
|||||||
@@ -209,6 +209,164 @@ class DefaultToolUtils(ToolUtils):
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
class Gemma4ToolUtils(ToolUtils):
|
||||||
|
r"""Gemma-4 tool using template."""
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_formatter(tools: list[dict[str, Any]]) -> str:
|
||||||
|
def _format_parameters(properties: dict[str, Any]) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
for name, schema in properties.items():
|
||||||
|
item_parts: list[str] = []
|
||||||
|
if schema.get("description"):
|
||||||
|
item_parts.append(f'description:<|"|>{schema["description"]}<|"|>')
|
||||||
|
if schema.get("type"):
|
||||||
|
item_parts.append(f'type:<|"|>{str(schema["type"]).upper()}<|"|>')
|
||||||
|
parts.append(f"{name}:{{{','.join(item_parts)}}}")
|
||||||
|
|
||||||
|
return ",".join(parts)
|
||||||
|
|
||||||
|
declarations: list[str] = []
|
||||||
|
for tool in tools:
|
||||||
|
function_data = tool.get("function", tool) if tool.get("type") == "function" else tool
|
||||||
|
declaration = (
|
||||||
|
f"declaration:{function_data['name']}"
|
||||||
|
+ "{"
|
||||||
|
+ f'description:<|"|>{function_data.get("description", "")}<|"|>'
|
||||||
|
)
|
||||||
|
|
||||||
|
params = function_data.get("parameters")
|
||||||
|
if params:
|
||||||
|
param_parts: list[str] = []
|
||||||
|
if params.get("properties"):
|
||||||
|
param_parts.append(f"properties:{{{_format_parameters(params['properties'])}}}")
|
||||||
|
|
||||||
|
if params.get("required"):
|
||||||
|
required_text = ",".join(f'<|"|>{item}<|"|>' for item in params["required"])
|
||||||
|
param_parts.append(f"required:[{required_text}]")
|
||||||
|
|
||||||
|
if params.get("type"):
|
||||||
|
param_parts.append(f'type:<|"|>{str(params["type"]).upper()}<|"|>')
|
||||||
|
|
||||||
|
declaration += f",parameters:{{{','.join(param_parts)}}}"
|
||||||
|
|
||||||
|
response_declaration = function_data.get("response")
|
||||||
|
if response_declaration:
|
||||||
|
response_parts: list[str] = []
|
||||||
|
if response_declaration.get("description"):
|
||||||
|
response_parts.append(f'description:<|"|>{response_declaration["description"]}<|"|>')
|
||||||
|
|
||||||
|
response_type = str(response_declaration.get("type", "")).upper()
|
||||||
|
|
||||||
|
if response_type == "OBJECT":
|
||||||
|
response_parts.append(f'type:<|"|>{response_type}<|"|>')
|
||||||
|
|
||||||
|
declaration += f",response:{{{','.join(response_parts)}}}"
|
||||||
|
|
||||||
|
declarations.append(declaration + "}")
|
||||||
|
|
||||||
|
return "\n".join(declarations)
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
|
||||||
|
regex = re.compile(r"<\|tool_call\>call:([^{\s]+)\{(.*?)\}<tool_call\|>", re.DOTALL)
|
||||||
|
matches = re.findall(regex, content)
|
||||||
|
if not matches:
|
||||||
|
return content
|
||||||
|
|
||||||
|
def _parse_arguments(arg_text: str) -> Any:
|
||||||
|
text = arg_text.strip()
|
||||||
|
if not text:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# `function_formatter` writes dict arguments as `k:v,...` inside `{...}`.
|
||||||
|
# The extractor captures only the inner text, so re-wrap it to parse as JSON object.
|
||||||
|
object_like_text = "{" + text + "}"
|
||||||
|
# Convert Gemma string markers (<|"|>value<|"|>) to valid JSON strings.
|
||||||
|
normalized = re.sub(
|
||||||
|
r"<\|\"\|\>(.*?)<\|\"\|\>",
|
||||||
|
lambda m: json.dumps(m.group(1), ensure_ascii=False),
|
||||||
|
object_like_text,
|
||||||
|
flags=re.DOTALL,
|
||||||
|
)
|
||||||
|
# Quote unquoted object keys so the payload can be parsed by json.loads.
|
||||||
|
normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized)
|
||||||
|
try:
|
||||||
|
return json.loads(normalized)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return text
|
||||||
|
|
||||||
|
results: list[FunctionCall] = []
|
||||||
|
for name, arg_block in matches:
|
||||||
|
parsed_arguments = _parse_arguments(arg_block)
|
||||||
|
if isinstance(parsed_arguments, str):
|
||||||
|
arguments = parsed_arguments
|
||||||
|
else:
|
||||||
|
arguments = json.dumps(parsed_arguments, ensure_ascii=False)
|
||||||
|
results.append(FunctionCall(name.strip(), arguments))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@override
|
||||||
|
@staticmethod
|
||||||
|
def function_formatter(functions: list["FunctionCall"]) -> str:
|
||||||
|
def _format_argument(argument: Any, escape_keys: bool = True) -> str:
|
||||||
|
if isinstance(argument, str):
|
||||||
|
return f'<|"|>{argument}<|"|>'
|
||||||
|
|
||||||
|
if isinstance(argument, bool):
|
||||||
|
return "true" if argument else "false"
|
||||||
|
|
||||||
|
if isinstance(argument, dict):
|
||||||
|
items: list[str] = []
|
||||||
|
for key in sorted(argument.keys()):
|
||||||
|
formatted_key = f'<|"|>{key}<|"|>' if escape_keys else str(key)
|
||||||
|
formatted_value = _format_argument(argument[key], escape_keys=escape_keys)
|
||||||
|
items.append(f"{formatted_key}:{formatted_value}")
|
||||||
|
return "{" + ",".join(items) + "}"
|
||||||
|
|
||||||
|
if isinstance(argument, (list, tuple)):
|
||||||
|
return "[" + ",".join(_format_argument(item, escape_keys=escape_keys) for item in argument) + "]"
|
||||||
|
|
||||||
|
if argument is None:
|
||||||
|
return "null"
|
||||||
|
|
||||||
|
return str(argument)
|
||||||
|
|
||||||
|
function_texts: list[str] = []
|
||||||
|
for function in functions:
|
||||||
|
name = function.name
|
||||||
|
raw_arguments = function.arguments
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_arguments = json.loads(raw_arguments)
|
||||||
|
except (TypeError, json.JSONDecodeError):
|
||||||
|
parsed_arguments = raw_arguments
|
||||||
|
|
||||||
|
call_text = f"<|tool_call>call:{name}" + "{"
|
||||||
|
if isinstance(parsed_arguments, dict):
|
||||||
|
args_text = []
|
||||||
|
for key in sorted(parsed_arguments.keys()):
|
||||||
|
value_text = _format_argument(parsed_arguments[key], escape_keys=False)
|
||||||
|
args_text.append(f"{key}:{value_text}")
|
||||||
|
|
||||||
|
call_text += ",".join(args_text)
|
||||||
|
elif isinstance(parsed_arguments, str):
|
||||||
|
call_text += parsed_arguments
|
||||||
|
else:
|
||||||
|
call_text += _format_argument(parsed_arguments, escape_keys=False)
|
||||||
|
|
||||||
|
call_text += "}<tool_call|>"
|
||||||
|
function_texts.append(call_text)
|
||||||
|
|
||||||
|
return "".join(function_texts)
|
||||||
|
|
||||||
class GLM4ToolUtils(ToolUtils):
|
class GLM4ToolUtils(ToolUtils):
|
||||||
r"""GLM-4 tool using template."""
|
r"""GLM-4 tool using template."""
|
||||||
@@ -723,6 +881,7 @@ class LFM2ToolUtils(ToolUtils):
|
|||||||
|
|
||||||
TOOLS = {
|
TOOLS = {
|
||||||
"default": DefaultToolUtils(),
|
"default": DefaultToolUtils(),
|
||||||
|
"gemma4": Gemma4ToolUtils(),
|
||||||
"glm4": GLM4ToolUtils(),
|
"glm4": GLM4ToolUtils(),
|
||||||
"llama3": Llama3ToolUtils(),
|
"llama3": Llama3ToolUtils(),
|
||||||
"lfm2": LFM2ToolUtils(),
|
"lfm2": LFM2ToolUtils(),
|
||||||
|
|||||||
@@ -865,6 +865,34 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Gemma-4-26B-A4B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "google/gemma-4-26B-A4B-it",
|
||||||
|
},
|
||||||
|
"Gemma-4-31B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "google/gemma-4-31B-it",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="gemma4",
|
||||||
|
multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"Gemma-4-E2B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "google/gemma-4-E2B-it",
|
||||||
|
},
|
||||||
|
"Gemma-4-E4B-Thinking": {
|
||||||
|
DownloadSource.DEFAULT: "google/gemma-4-E4B-it",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="gemma4n",
|
||||||
|
multimodal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"GLM-4-9B": {
|
"GLM-4-9B": {
|
||||||
|
|||||||
@@ -219,6 +219,13 @@ _register_composite_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_composite_model(
|
||||||
|
model_type="gemma4",
|
||||||
|
vision_model_keys=["vision_tower", "audio_tower"],
|
||||||
|
lora_conflict_keys=["per_layer_projection_norm"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# copied from qwen2vl
|
# copied from qwen2vl
|
||||||
_register_composite_model(
|
_register_composite_model(
|
||||||
model_type="glm4v",
|
model_type="glm4v",
|
||||||
|
|||||||
@@ -48,7 +48,10 @@ def run_sft(
|
|||||||
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
|
"hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
|
||||||
)
|
)
|
||||||
|
|
||||||
from hyper_parallel.integration.llamafactory import HyperParallelArguments, HyperParallelTrainer # pylint: disable=C0415
|
from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415
|
||||||
|
HyperParallelArguments,
|
||||||
|
HyperParallelTrainer,
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer_module = load_tokenizer(model_args)
|
tokenizer_module = load_tokenizer(model_args)
|
||||||
tokenizer = tokenizer_module["tokenizer"]
|
tokenizer = tokenizer_module["tokenizer"]
|
||||||
@@ -128,9 +131,10 @@ def run_sft(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if finetuning_args.use_badam:
|
if finetuning_args.use_badam:
|
||||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
|
|
||||||
|
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore[import]
|
||||||
|
|
||||||
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
trainer.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, trainer.accelerator)
|
||||||
trainer.add_callback(BAdamCallback)
|
trainer.add_callback(BAdamCallback)
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ TEXT_MESSAGES = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
VIDEO_MESSAGES = [
|
VIDEO_MESSAGES = [
|
||||||
{"role": "user", "content": "<video>What is in this viode?"},
|
{"role": "user", "content": "<video>What is in this video?"},
|
||||||
{"role": "assistant", "content": "A cat."},
|
{"role": "assistant", "content": "A cat."},
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -210,6 +210,34 @@ def test_gemma3_plugin():
|
|||||||
_check_plugin(**check_inputs)
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
|
@pytest.mark.skipif(not is_transformers_version_greater_than("5.6.0"), reason="Requires transformers>=5.6.0")
|
||||||
|
def test_gemma4_plugin():
|
||||||
|
tokenizer_module = _load_tokenizer_module(model_name_or_path="google/gemma-4-31B-it")
|
||||||
|
processor = tokenizer_module["processor"]
|
||||||
|
gemma4_plugin = get_mm_plugin(name="gemma4", image_token="<|image|>", video_token="<|video|>")
|
||||||
|
check_inputs = {"plugin": gemma4_plugin, **tokenizer_module}
|
||||||
|
# validate
|
||||||
|
mm_inputs = gemma4_plugin._get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, processor)
|
||||||
|
num_image_soft_tokens = 256 # when we use default max_soft_tokens=280
|
||||||
|
image_token = getattr(processor, "image_token")
|
||||||
|
boi_token = getattr(processor, "boi_token")
|
||||||
|
eoi_token = getattr(processor, "eoi_token")
|
||||||
|
|
||||||
|
expected_mm_type_ids = [[int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS]
|
||||||
|
check_inputs["expected_mm_messages"] = [
|
||||||
|
{"role": "user", "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?"},
|
||||||
|
{"role": "assistant", "content": "A cat."},
|
||||||
|
]
|
||||||
|
for key in ("num_soft_tokens_per_image",):
|
||||||
|
mm_inputs.pop(key, None)
|
||||||
|
|
||||||
|
mm_inputs["mm_token_type_ids"] = expected_mm_type_ids
|
||||||
|
check_inputs["expected_mm_inputs"] = mm_inputs
|
||||||
|
check_inputs["expected_no_mm_inputs"] = {"mm_token_type_ids": expected_mm_type_ids}
|
||||||
|
_check_plugin(**check_inputs)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.runs_on(["cpu", "mps"])
|
@pytest.mark.runs_on(["cpu", "mps"])
|
||||||
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
|
@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
|
||||||
def test_internvl_plugin():
|
def test_internvl_plugin():
|
||||||
|
|||||||
Reference in New Issue
Block a user