mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00

[model] add llama4 (#7611)

parent 61b24c3827
commit 6c200fd218
@@ -1,6 +1,4 @@
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;sys_platform == 'darwin'
+transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
 datasets>=2.16.0,<=3.4.1
 accelerate>=0.34.0,<=1.5.2
 peft>=0.14.0,<=0.15.0
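Not part of the diff: a minimal sketch of how the relaxed transformers pin can be checked against the local environment, using importlib.metadata and the packaging library (assumed to be installed); the specifier string is copied from the new requirement line.

# Sketch only: confirm the installed transformers version satisfies the new pin.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0")
installed = version("transformers")
print(installed, "ok" if installed in spec else "outside the supported range")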
scripts/convert_ckpt/tiny_llama4.py (new file, 39 lines)
@@ -0,0 +1,39 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig
+
+
+if __name__ == "__main__":
+    vision_config = Llama4VisionConfig(
+        hidden_size=1408,
+        image_size=336,
+        intermediate_size=5632,
+        num_attention_heads=16,
+        num_hidden_layers=4,
+        vision_output_dim=4096,
+    )
+    text_config = Llama4TextConfig(
+        hidden_size=512,
+        intermediate_size=1024,
+        intermediate_size_mlp=1024,
+        num_hidden_layers=4,
+        num_attention_heads=8,
+        num_key_value_heads=2,
+        head_dim=512 // 8,
+        num_local_experts=2,
+    )
+    config = Llama4Config(vision_config=vision_config, text_config=text_config)
+    model = Llama4ForConditionalGeneration._from_config(config)
+    model.save_pretrained("tiny-llama4")
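Usage note (not in the commit): the script above only writes the tiny model's weights and config, so reloading it is mainly a smoke test; a hedged sketch, assuming transformers>=4.51.0 so the Llama4 classes are available:

# Sketch: reload the checkpoint written by scripts/convert_ckpt/tiny_llama4.py.
from transformers import Llama4ForConditionalGeneration

model = Llama4ForConditionalGeneration.from_pretrained("tiny-llama4")
num_params = sum(p.numel() for p in model.parameters())
print(f"tiny-llama4 loaded with {num_params:,} parameters")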
@@ -19,7 +19,7 @@ Level:

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
+    transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
     datasets>=2.16.0,<=3.4.1
     accelerate>=0.34.0,<=1.5.2
     peft>=0.14.0,<=0.15.0
@@ -466,6 +466,73 @@ class Gemma3Plugin(BasePlugin):
         return mm_inputs


+@dataclass
+class Llama4Plugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
+                num_patches_per_chunk = int(
+                    (image_height // processor.patch_size)
+                    * (image_width // processor.patch_size)
+                    // processor.downsample_ratio
+                )
+                aspect_ratios = mm_inputs.pop("aspect_ratios")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            placeholder_count = content.count(IMAGE_PLACEHOLDER)
+            if self.expand_mm_tokens:
+                prompt_splits = content.split(IMAGE_PLACEHOLDER)
+                new_content = []
+                for local_image_index, split_part in enumerate(prompt_splits):
+                    new_content.append(split_part)
+                    if local_image_index < placeholder_count:
+                        tokens_for_this_image = processor._prompt_split_image(
+                            aspect_ratios[num_image_tokens], num_patches_per_chunk
+                        )
+                        num_image_tokens += 1
+                        new_content.append(tokens_for_this_image)
+
+                content = "".join(new_content)
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("aspect_ratios", None)
+        return mm_inputs
+
+
 @dataclass
 class LlavaPlugin(BasePlugin):
     @override
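To make the intent of process_messages above easier to follow, here is a standalone sketch of the same expansion pattern: the prompt is split on the image placeholder and a per-image token string is spliced back between the pieces. The fake_image_tokens helper is hypothetical; the real plugin calls processor._prompt_split_image with the image's aspect ratio and patch count.

# Standalone illustration of the placeholder-expansion loop (not the plugin itself).
IMAGE_PLACEHOLDER = "<image>"

def fake_image_tokens(image_index: int, num_patches: int) -> str:
    # Hypothetical stand-in for processor._prompt_split_image(...).
    return "<|image_start|>" + "<|patch|>" * num_patches + "<|image_end|>"

def expand_image_placeholders(content: str, num_patches: int) -> tuple[str, int]:
    placeholder_count = content.count(IMAGE_PLACEHOLDER)
    prompt_splits = content.split(IMAGE_PLACEHOLDER)
    new_content, num_image_tokens = [], 0
    for local_image_index, split_part in enumerate(prompt_splits):
        new_content.append(split_part)
        if local_image_index < placeholder_count:
            new_content.append(fake_image_tokens(num_image_tokens, num_patches))
            num_image_tokens += 1

    return "".join(new_content), num_image_tokens

expanded, count = expand_image_placeholders(f"Describe {IMAGE_PLACEHOLDER} briefly.", num_patches=4)
print(count, expanded)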
@@ -1485,6 +1552,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
@ -968,6 +968,26 @@ register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_template(
|
||||||
|
name="llama4",
|
||||||
|
format_user=StringFormatter(
|
||||||
|
slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
|
||||||
|
),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
|
||||||
|
format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=[
|
||||||
|
"<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
|
||||||
|
]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="llama3"),
|
||||||
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
|
stop_words=["<|eot|>", "<|eom|>"],
|
||||||
|
mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# copied from llama3 template
|
# copied from llama3 template
|
||||||
register_template(
|
register_template(
|
||||||
name="mllama",
|
name="mllama",
|
||||||
|
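For reference (not in the diff), concatenating the slots above for one system plus user turn yields a prompt shaped roughly like the sketch below; the BOS string is an assumption based on the Llama 3 convention, since the real template takes it from the tokenizer.

# Sketch: what the llama4 template roughly renders for one turn (BOS token assumed).
system = "You are a helpful assistant."
user = "What is in this image? <|image|>"

prompt = (
    "<|begin_of_text|>"
    f"<|header_start|>system<|header_end|>\n\n{system}<|eot|>"
    f"<|header_start|>user<|header_end|>\n\n{user}<|eot|>"
    "<|header_start|>assistant<|header_end|>\n\n"
)
print(prompt)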
@@ -1111,6 +1111,30 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Llama-4-Scout-17B-16E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
+        },
+        "Llama-4-Scout-17B-16E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
+        },
+        "Llama-4-Maverick-17B-128E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
+        },
+        "Llama-4-Maverick-17B-128E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
+        },
+    },
+    template="llama4",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
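Not part of the commit: a small sketch of how an entry like this could be resolved to a concrete repo id, with a plain dict standing in for the registered group and a hypothetical use_modelscope flag.

# Sketch: pick a download mirror for one of the new entries (illustration only).
LLAMA4_SOURCES = {
    "Llama-4-Scout-17B-16E-Instruct": {
        "default": "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "modelscope": "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
    },
}

def resolve_repo_id(model_name: str, use_modelscope: bool = False) -> str:
    sources = LLAMA4_SOURCES[model_name]
    return sources["modelscope"] if use_modelscope else sources["default"]

print(resolve_repo_id("Llama-4-Scout-17B-16E-Instruct", use_modelscope=True))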
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.4.1")
     check_version("accelerate>=0.34.0,<=1.5.2")
     check_version("peft>=0.14.0,<=0.15.0")
@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable

     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: torch.nn.Module = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__

         has_grad = False
         if any(param.requires_grad for param in module.parameters()):
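The three added lines handle the case where the checkpointed forward call arrives wrapped in functools.partial, so the owning module must be read from func.func.__self__ rather than func.__self__. A minimal sketch of that difference, assuming only torch is installed:

# Sketch: why a partial-wrapped bound method hides __self__ one level deeper.
from functools import partial

import torch

layer = torch.nn.Linear(4, 4)
bound = layer.forward                      # plain bound method
wrapped = partial(layer.forward)           # e.g. partial(forward, **extra_kwargs)

assert bound.__self__ is layer             # works for the plain bound method
assert not hasattr(wrapped, "__self__")    # partial objects have no __self__
assert wrapped.func.__self__ is layer      # the module is reachable via .func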
@@ -203,6 +203,12 @@ _register_composite_model(
 )


+_register_composite_model(
+    model_type="llama4",
+    vision_model_keys=["vision_model"],
+)
+
+
 _register_composite_model(
     model_type="llava",
 )
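Not from the diff: registering vision_model_keys=["vision_model"] marks which submodule is the vision tower; the sketch below shows the kind of prefix-based freezing such keys typically enable. The freeze_vision_tower helper and the prefix check are assumptions for illustration, not the repo's exact code.

# Sketch: freeze parameters whose names fall under the registered vision-tower keys.
import torch

def freeze_vision_tower(model: torch.nn.Module, vision_model_keys: list[str]) -> None:
    for name, param in model.named_parameters():
        if any(name.startswith(key) for key in vision_model_keys):
            param.requires_grad_(False)  # keep the vision tower frozen during fine-tuning

# Usage (assuming `model` is a loaded Llama4ForConditionalGeneration):
# freeze_vision_tower(model, vision_model_keys=["vision_model"])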
@@ -243,7 +243,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
             mini_batch = {
                 "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
-                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
             }
             mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
             mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
@@ -1 +1,2 @@
-0.9.3.100
+# change if test fails
+0.9.3.101