Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

[model] add llama4 (#7611)

parent d4cfa9507e
commit 831e7f1cfd
@@ -1,6 +1,4 @@
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;sys_platform == 'darwin'
+transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
datasets>=2.16.0,<=3.4.1
accelerate>=0.34.0,<=1.5.2
peft>=0.14.0,<=0.15.0
scripts/convert_ckpt/tiny_llama4.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig


if __name__ == "__main__":
    vision_config = Llama4VisionConfig(
        hidden_size=1408,
        image_size=336,
        intermediate_size=5632,
        num_attention_heads=16,
        num_hidden_layers=4,
        vision_output_dim=4096,
    )
    text_config = Llama4TextConfig(
        hidden_size=512,
        intermediate_size=1024,
        intermediate_size_mlp=1024,
        num_hidden_layers=4,
        num_attention_heads=8,
        num_key_value_heads=2,
        head_dim=512 // 8,
        num_local_experts=2,
    )
    config = Llama4Config(vision_config=vision_config, text_config=text_config)
    model = Llama4ForConditionalGeneration._from_config(config)
    model.save_pretrained("tiny-llama4")
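The tiny checkpoint written by this script can be reloaded for quick smoke tests. A minimal sketch, assuming a transformers release that ships the Llama4 classes (>=4.51.0) and that the script above was run in the current directory:

import torch
from transformers import Llama4ForConditionalGeneration

# Reload the randomly initialized tiny model saved by scripts/convert_ckpt/tiny_llama4.py.
model = Llama4ForConditionalGeneration.from_pretrained("tiny-llama4", torch_dtype=torch.float32)

# Text-only dummy forward pass, just to confirm the graph executes end to end.
input_ids = torch.randint(0, model.config.text_config.vocab_size, (1, 8))
with torch.no_grad():
    outputs = model(input_ids=input_ids)
print(outputs.logits.shape)  # expected: (1, 8, vocab_size)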
@@ -19,7 +19,7 @@ Level:

Dependency graph:
  main:
-    transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
+    transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
    datasets>=2.16.0,<=3.4.1
    accelerate>=0.34.0,<=1.5.2
    peft>=0.14.0,<=0.15.0
@@ -466,6 +466,73 @@ class Gemma3Plugin(BasePlugin):
        return mm_inputs


@dataclass
class Llama4Plugin(BasePlugin):
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        self._validate_input(processor, images, videos, audios)
        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            if "pixel_values" in mm_inputs:
                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
                num_patches_per_chunk = int(
                    (image_height // processor.patch_size)
                    * (image_width // processor.patch_size)
                    // processor.downsample_ratio
                )
                aspect_ratios = mm_inputs.pop("aspect_ratios")

        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            placeholder_count = content.count(IMAGE_PLACEHOLDER)
            if self.expand_mm_tokens:
                prompt_splits = content.split(IMAGE_PLACEHOLDER)
                new_content = []
                for local_image_index, split_part in enumerate(prompt_splits):
                    new_content.append(split_part)
                    if local_image_index < placeholder_count:
                        tokens_for_this_image = processor._prompt_split_image(
                            aspect_ratios[num_image_tokens], num_patches_per_chunk
                        )
                        num_image_tokens += 1
                        new_content.append(tokens_for_this_image)

                content = "".join(new_content)

            message["content"] = content

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        mm_inputs.pop("aspect_ratios", None)
        return mm_inputs


@dataclass
class LlavaPlugin(BasePlugin):
    @override
@@ -1485,6 +1552,7 @@ class VideoLlavaPlugin(BasePlugin):
PLUGINS = {
    "base": BasePlugin,
    "gemma3": Gemma3Plugin,
    "llama4": Llama4Plugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
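For intuition, the number of image tokens Llama4Plugin expands per placeholder is just the patch grid divided by the spatial downsample factor. A minimal sketch of that arithmetic, using illustrative values for patch_size and downsample_ratio (assumptions here, not necessarily the Llama 4 processor defaults):

# Sketch of the num_patches_per_chunk computation from Llama4Plugin.process_messages.
# patch_size and downsample_ratio are assumed values for illustration only.
image_height, image_width = 336, 336
patch_size, downsample_ratio = 14, 4
num_patches_per_chunk = int(
    (image_height // patch_size) * (image_width // patch_size) // downsample_ratio
)
print(num_patches_per_chunk)  # (24 * 24) // 4 = 144 with these assumed values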
@@ -968,6 +968,26 @@ register_template(
)


register_template(
    name="llama4",
    format_user=StringFormatter(
        slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
    format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            "<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot|>", "<|eom|>"],
    mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
)


# copied from llama3 template
register_template(
    name="mllama",
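As a quick sanity check, assembling the slots registered above by hand shows what a single-turn llama4 prompt looks like. This is a sketch only: it bypasses LLaMA-Factory's formatter classes, and the bos_token value is assumed to be Llama 4's <|begin_of_text|>.

# Hand-assembled prompt following the llama4 template slots; illustrative only.
system_msg = "You are a helpful assistant."
user_msg = "Hello!"
prompt = (
    "<|begin_of_text|>"  # assumed bos_token; check the tokenizer's actual BOS token
    + "<|header_start|>system<|header_end|>\n\n" + system_msg + "<|eot|>"
    + "<|header_start|>user<|header_end|>\n\n" + user_msg + "<|eot|>"
    + "<|header_start|>assistant<|header_end|>\n\n"
)
print(prompt)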
@@ -1111,6 +1111,30 @@ register_model_group(
)


register_model_group(
    models={
        "Llama-4-Scout-17B-16E": {
            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
        },
        "Llama-4-Scout-17B-16E-Instruct": {
            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
        },
        "Llama-4-Maverick-17B-128E": {
            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
        },
        "Llama-4-Maverick-17B-128E-Instruct": {
            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
        },
    },
    template="llama4",
    multimodal=True,
)


register_model_group(
    models={
        "LLaVA-1.5-7B-Chat": {
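The control tokens referenced by the llama4 template and plugin ship with the checkpoints registered above. A small sketch to confirm they resolve to real token ids (requires access to the gated meta-llama repo; any of the registered repos would do):

from transformers import AutoTokenizer

# Look up the template's control tokens in one of the registered checkpoints.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
for token in ["<|header_start|>", "<|header_end|>", "<|eot|>", "<|eom|>", "<|image|>"]:
    print(token, tokenizer.convert_tokens_to_ids(token))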
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

def check_dependencies() -> None:
    r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
    check_version("datasets>=2.16.0,<=3.4.1")
    check_version("accelerate>=0.34.0,<=1.5.2")
    check_version("peft>=0.14.0,<=0.15.0")
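check_version parses a PEP 508-style requirement string and compares it against the installed package. A minimal sketch of what such a check amounts to (this is the standard packaging/importlib recipe, not the repo's actual implementation):

from importlib.metadata import version
from packaging.requirements import Requirement

# Parse the same requirement string used above and test the installed transformers version.
requirement = Requirement("transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0")
installed = version("transformers")
if not requirement.specifier.contains(installed, prereleases=True):
    raise RuntimeError(f"transformers {installed} does not satisfy {requirement.specifier}")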
@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable

    @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: torch.nn.Module = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__

        has_grad = False
        if any(param.requires_grad for param in module.parameters()):
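The new branch matters because gradient checkpointing sometimes receives a layer's forward wrapped in functools.partial, in which case the owning module hangs off .func rather than the callable itself. A small standalone sketch of that distinction (illustrative, not the repo's code):

from functools import partial

import torch


class Block(torch.nn.Module):
    def forward(self, x, use_cache=False):  # use_cache is an arbitrary illustrative kwarg
        return x


block = Block()
bound = block.forward                               # plain bound method: module is bound.__self__
wrapped = partial(block.forward, use_cache=False)   # partial: module is wrapped.func.__self__

assert bound.__self__ is block
assert not hasattr(wrapped, "__self__")
assert wrapped.func.__self__ is block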
@@ -203,6 +203,12 @@ _register_composite_model(
)


_register_composite_model(
    model_type="llama4",
    vision_model_keys=["vision_model"],
)


_register_composite_model(
    model_type="llava",
)
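vision_model_keys tells the model utilities which submodule is the vision tower (here the model's vision_model attribute), for example so it can be frozen when only the language model is tuned. A rough sketch of that idea, under the assumption that freezing works by parameter-name prefix (the repo's actual logic may differ):

import torch

# Hypothetical helper: freeze every parameter that lives under one of the registered vision keys.
def freeze_vision_tower(model: torch.nn.Module, vision_model_keys=("vision_model",)) -> None:
    for name, param in model.named_parameters():
        if any(name == key or name.startswith(key + ".") or f".{key}." in name for key in vision_model_keys):
            param.requires_grad_(False)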
@@ -243,7 +243,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
                mini_batch = {
                    "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
-                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+                    "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
                }
                mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
                mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
@@ -1 +1,2 @@
-0.9.3.100
+# change if test fails
+0.9.3.101