mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 03:02:51 +08:00
[model] add llama4 (#7611)
This commit is contained in:
parent
61b24c3827
commit
6c200fd218
@ -1,6 +1,4 @@
|
||||
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' and sys_platform != 'darwin'
|
||||
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' and sys_platform != 'darwin'
|
||||
transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;sys_platform == 'darwin'
|
||||
transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||
datasets>=2.16.0,<=3.4.1
|
||||
accelerate>=0.34.0,<=1.5.2
|
||||
peft>=0.14.0,<=0.15.0
|
||||
|
39
scripts/convert_ckpt/tiny_llama4.py
Normal file
39
scripts/convert_ckpt/tiny_llama4.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
vision_config = Llama4VisionConfig(
|
||||
hidden_size=1408,
|
||||
image_size=336,
|
||||
intermediate_size=5632,
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=4,
|
||||
vision_output_dim=4096,
|
||||
)
|
||||
text_config = Llama4TextConfig(
|
||||
hidden_size=512,
|
||||
intermediate_size=1024,
|
||||
intermediate_size_mlp=1024,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=2,
|
||||
head_dim=512 // 8,
|
||||
num_local_experts=2,
|
||||
)
|
||||
config = Llama4Config(vision_config=vision_config, text_config=text_config)
|
||||
model = Llama4ForConditionalGeneration._from_config(config)
|
||||
model.save_pretrained("tiny-llama4")
|
@ -19,7 +19,7 @@ Level:
|
||||
|
||||
Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||
transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
|
||||
datasets>=2.16.0,<=3.4.1
|
||||
accelerate>=0.34.0,<=1.5.2
|
||||
peft>=0.14.0,<=0.15.0
|
||||
|
@ -466,6 +466,73 @@ class Gemma3Plugin(BasePlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama4Plugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: list[dict[str, str]],
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> list[dict[str, str]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
if self.expand_mm_tokens:
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
if "pixel_values" in mm_inputs:
|
||||
image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
|
||||
num_patches_per_chunk = int(
|
||||
(image_height // processor.patch_size)
|
||||
* (image_width // processor.patch_size)
|
||||
// processor.downsample_ratio
|
||||
)
|
||||
aspect_ratios = mm_inputs.pop("aspect_ratios")
|
||||
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
content = message["content"]
|
||||
placeholder_count = content.count(IMAGE_PLACEHOLDER)
|
||||
if self.expand_mm_tokens:
|
||||
prompt_splits = content.split(IMAGE_PLACEHOLDER)
|
||||
new_content = []
|
||||
for local_image_index, split_part in enumerate(prompt_splits):
|
||||
new_content.append(split_part)
|
||||
if local_image_index < placeholder_count:
|
||||
tokens_for_this_image = processor._prompt_split_image(
|
||||
aspect_ratios[num_image_tokens], num_patches_per_chunk
|
||||
)
|
||||
num_image_tokens += 1
|
||||
new_content.append(tokens_for_this_image)
|
||||
|
||||
content = "".join(new_content)
|
||||
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: list["ImageInput"],
|
||||
videos: list["VideoInput"],
|
||||
audios: list["AudioInput"],
|
||||
imglens: list[int],
|
||||
vidlens: list[int],
|
||||
audlens: list[int],
|
||||
batch_ids: list[list[int]],
|
||||
processor: Optional["MMProcessor"],
|
||||
) -> dict[str, Union[list[int], "torch.Tensor"]]:
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs.pop("aspect_ratios", None)
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
@ -1485,6 +1552,7 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"llama4": Llama4Plugin,
|
||||
"llava": LlavaPlugin,
|
||||
"llava_next": LlavaNextPlugin,
|
||||
"llava_next_video": LlavaNextVideoPlugin,
|
||||
|
@ -968,6 +968,26 @@ register_template(
|
||||
)
|
||||
|
||||
|
||||
register_template(
|
||||
name="llama4",
|
||||
format_user=StringFormatter(
|
||||
slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
|
||||
),
|
||||
format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
|
||||
format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
|
||||
format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
|
||||
format_observation=StringFormatter(
|
||||
slots=[
|
||||
"<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
|
||||
]
|
||||
),
|
||||
format_tools=ToolFormatter(tool_format="llama3"),
|
||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||
stop_words=["<|eot|>", "<|eom|>"],
|
||||
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
|
||||
)
|
||||
|
||||
|
||||
# copied from llama3 template
|
||||
register_template(
|
||||
name="mllama",
|
||||
|
@ -1111,6 +1111,30 @@ register_model_group(
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"Llama-4-Scout-17B-16E": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
|
||||
},
|
||||
"Llama-4-Scout-17B-16E-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
|
||||
},
|
||||
"Llama-4-Maverick-17B-128E": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
|
||||
},
|
||||
"Llama-4-Maverick-17B-128E-Instruct": {
|
||||
DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
|
||||
DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
|
||||
},
|
||||
},
|
||||
template="llama4",
|
||||
multimodal=True,
|
||||
)
|
||||
|
||||
|
||||
register_model_group(
|
||||
models={
|
||||
"LLaVA-1.5-7B-Chat": {
|
||||
|
@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
|
||||
|
||||
def check_dependencies() -> None:
|
||||
r"""Check the version of the required packages."""
|
||||
check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
|
||||
check_version("datasets>=2.16.0,<=3.4.1")
|
||||
check_version("accelerate>=0.34.0,<=1.5.2")
|
||||
check_version("peft>=0.14.0,<=0.15.0")
|
||||
|
@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
|
||||
|
||||
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
|
||||
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
|
||||
module: torch.nn.Module = func.__self__
|
||||
if isinstance(func, partial):
|
||||
module: torch.nn.Module = func.func.__self__
|
||||
else:
|
||||
module: torch.nn.Module = func.__self__
|
||||
|
||||
has_grad = False
|
||||
if any(param.requires_grad for param in module.parameters()):
|
||||
|
@ -203,6 +203,12 @@ _register_composite_model(
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="llama4",
|
||||
vision_model_keys=["vision_model"],
|
||||
)
|
||||
|
||||
|
||||
_register_composite_model(
|
||||
model_type="llava",
|
||||
)
|
||||
|
@ -243,7 +243,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
|
||||
mini_batch = {
|
||||
"input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
|
||||
"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
|
||||
"attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
|
||||
}
|
||||
mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
|
||||
mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
|
||||
|
@ -1 +1,2 @@
|
||||
0.9.3.100
|
||||
# change if test fails
|
||||
0.9.3.101
|
||||
|
Loading…
x
Reference in New Issue
Block a user