[model] add llama4 (#7611)

hoshi-hiyouga 2025-04-06 13:42:31 +08:00 committed by GitHub
parent 61b24c3827
commit 6c200fd218
11 changed files with 167 additions and 8 deletions

View File

@@ -1,6 +1,4 @@
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;sys_platform == 'darwin'
+transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.4.1
accelerate>=0.34.0,<=1.5.2
peft>=0.14.0,<=0.15.0

View File

@@ -0,0 +1,39 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig


if __name__ == "__main__":
    # build a tiny, randomly initialized Llama 4 model for fast tests
vision_config = Llama4VisionConfig(
hidden_size=1408,
image_size=336,
intermediate_size=5632,
num_attention_heads=16,
num_hidden_layers=4,
vision_output_dim=4096,
)
text_config = Llama4TextConfig(
hidden_size=512,
intermediate_size=1024,
intermediate_size_mlp=1024,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=2,
head_dim=512 // 8,
num_local_experts=2,
)
config = Llama4Config(vision_config=vision_config, text_config=text_config)
model = Llama4ForConditionalGeneration._from_config(config)
model.save_pretrained("tiny-llama4")
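
For reference (not part of the commit): a quick smoke test for the checkpoint the script writes out. The load path matches the save_pretrained call above; everything else is an assumption.

    from transformers import Llama4ForConditionalGeneration

    # reload the randomly initialized tiny model and sanity-check its size
    model = Llama4ForConditionalGeneration.from_pretrained("tiny-llama4")
    print(f"{sum(p.numel() for p in model.parameters()):,} parameters")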

View File

@@ -19,7 +19,7 @@ Level:
Dependency graph:
main:
-    transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
+    transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.4.1
accelerate>=0.34.0,<=1.5.2
peft>=0.14.0,<=0.15.0

View File

@@ -466,6 +466,73 @@ class Gemma3Plugin(BasePlugin):
        return mm_inputs


@dataclass
class Llama4Plugin(BasePlugin):
@override
def process_messages(
self,
messages: list[dict[str, str]],
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
processor: Optional["MMProcessor"],
) -> list[dict[str, str]]:
self._validate_input(processor, images, videos, audios)
if self.expand_mm_tokens:
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
if "pixel_values" in mm_inputs:
image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
num_patches_per_chunk = int(
(image_height // processor.patch_size)
* (image_width // processor.patch_size)
// processor.downsample_ratio
)
aspect_ratios = mm_inputs.pop("aspect_ratios")
num_image_tokens = 0
messages = deepcopy(messages)
for message in messages:
content = message["content"]
placeholder_count = content.count(IMAGE_PLACEHOLDER)
if self.expand_mm_tokens:
prompt_splits = content.split(IMAGE_PLACEHOLDER)
new_content = []
for local_image_index, split_part in enumerate(prompt_splits):
new_content.append(split_part)
if local_image_index < placeholder_count:
tokens_for_this_image = processor._prompt_split_image(
aspect_ratios[num_image_tokens], num_patches_per_chunk
)
num_image_tokens += 1
new_content.append(tokens_for_this_image)
content = "".join(new_content)
message["content"] = content
if len(images) != num_image_tokens:
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
        return messages

    @override
    def get_mm_inputs(
self,
images: list["ImageInput"],
videos: list["VideoInput"],
audios: list["AudioInput"],
imglens: list[int],
vidlens: list[int],
audlens: list[int],
batch_ids: list[list[int]],
processor: Optional["MMProcessor"],
) -> dict[str, Union[list[int], "torch.Tensor"]]:
self._validate_input(processor, images, videos, audios)
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
mm_inputs.pop("aspect_ratios", None)
return mm_inputs
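
The placeholder expansion in process_messages is easier to see in isolation. A self-contained sketch (the placeholder and token strings are hypothetical stand-ins, not the real Llama 4 tokens):

    IMAGE_PLACEHOLDER = "<image>"

    def expand(content: str, per_image_tokens: list[str]) -> str:
        # split on the placeholder, then re-join with one expanded token
        # string inserted per original placeholder, as the plugin does
        parts = content.split(IMAGE_PLACEHOLDER)
        out = []
        for i, part in enumerate(parts):
            out.append(part)
            if i < len(per_image_tokens):
                out.append(per_image_tokens[i])
        return "".join(out)

    print(expand("a <image> b <image> c", ["<|img1|>", "<|img2|>"]))
    # -> "a <|img1|> b <|img2|> c"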
@dataclass
class LlavaPlugin(BasePlugin):
@override
@@ -1485,6 +1552,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
"base": BasePlugin,
"gemma3": Gemma3Plugin,
"llama4": Llama4Plugin,
"llava": LlavaPlugin,
"llava_next": LlavaNextPlugin,
"llava_next_video": LlavaNextVideoPlugin,

View File

@@ -968,6 +968,26 @@ register_template(
)


register_template(
name="llama4",
format_user=StringFormatter(
slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
),
format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
format_observation=StringFormatter(
slots=[
"<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
]
),
format_tools=ToolFormatter(tool_format="llama3"),
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
stop_words=["<|eot|>", "<|eom|>"],
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
)
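
Assembled from the slots above, a single-turn llama4 prompt should render roughly like this (illustrative; "<|begin_of_text|>" as the bos_token is an assumption, not something this diff shows):

    prompt = (
        "<|begin_of_text|>"                                            # format_prefix: bos_token
        "<|header_start|>system<|header_end|>\n\nBe concise.<|eot|>"   # format_system
        "<|header_start|>user<|header_end|>\n\nHi!<|eot|>"             # format_user ...
        "<|header_start|>assistant<|header_end|>\n\n"                  # ... ends with the assistant header
    )
    # the reply itself is closed by format_assistant: "{{content}}<|eot|>"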
# copied from llama3 template
register_template(
name="mllama",

View File

@@ -1111,6 +1111,30 @@ register_model_group(
)


register_model_group(
models={
"Llama-4-Scout-17B-16E": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
},
"Llama-4-Scout-17B-16E-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
},
"Llama-4-Maverick-17B-128E": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
},
"Llama-4-Maverick-17B-128E-Instruct": {
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
},
},
template="llama4",
multimodal=True,
)


register_model_group(
models={
"LLaVA-1.5-7B-Chat": {

View File

@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
def check_dependencies() -> None:
r"""Check the version of the required packages."""
check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
check_version("datasets>=2.16.0,<=3.4.1")
check_version("accelerate>=0.34.0,<=1.5.2")
check_version("peft>=0.14.0,<=0.15.0")

View File

@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: torch.nn.Module = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__
has_grad = False
if any(param.requires_grad for param in module.parameters()):
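
The reason for the new partial branch, as a standalone repro (illustrative only): functools.partial hides a bound method's __self__, which is only reachable through .func:

    from functools import partial

    class Block:
        def forward(self, x):
            return x

    block = Block()
    wrapped = partial(block.forward)
    assert not hasattr(wrapped, "__self__")  # the old lookup fails here
    assert wrapped.func.__self__ is block    # the new branch recovers the module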

View File

@@ -203,6 +203,12 @@ _register_composite_model(
)


_register_composite_model(
model_type="llama4",
vision_model_keys=["vision_model"],
)


_register_composite_model(
model_type="llava",
)

View File

@@ -243,7 +243,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
mini_batch = {
"input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
}
mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
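
For context, the slicing pattern used in the loop above, reduced to a standalone example (shapes and sizes here are made up):

    import torch

    batch = {"input_ids": torch.arange(12).view(4, 3), "attention_mask": torch.ones(4, 3)}
    batch_size, mini_batch_size = 4, 2
    for idx in range(0, batch_size, mini_batch_size):
        # each mini-batch holds rows [idx, idx + mini_batch_size) of every field
        mini_batch = {key: value[idx : idx + mini_batch_size] for key, value in batch.items()}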

View File

@@ -1 +1,2 @@
-0.9.3.100
+# change if test fails
+0.9.3.101