diff --git a/requirements.txt b/requirements.txt
index bae72eff..4b142057 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,4 @@
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.*;python_version<'3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10' and sys_platform != 'darwin'
-transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.*;sys_platform == 'darwin'
+transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
 datasets>=2.16.0,<=3.4.1
 accelerate>=0.34.0,<=1.5.2
 peft>=0.14.0,<=0.15.0
diff --git a/scripts/convert_ckpt/tiny_llama4.py b/scripts/convert_ckpt/tiny_llama4.py
new file mode 100644
index 00000000..2a96cfa6
--- /dev/null
+++ b/scripts/convert_ckpt/tiny_llama4.py
@@ -0,0 +1,39 @@
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig
+
+
+if __name__ == "__main__":
+    vision_config = Llama4VisionConfig(
+        hidden_size=1408,
+        image_size=336,
+        intermediate_size=5632,
+        num_attention_heads=16,
+        num_hidden_layers=4,
+        vision_output_dim=4096,
+    )
+    text_config = Llama4TextConfig(
+        hidden_size=512,
+        intermediate_size=1024,
+        intermediate_size_mlp=1024,
+        num_hidden_layers=4,
+        num_attention_heads=8,
+        num_key_value_heads=2,
+        head_dim=512 // 8,
+        num_local_experts=2,
+    )
+    config = Llama4Config(vision_config=vision_config, text_config=text_config)
+    model = Llama4ForConditionalGeneration._from_config(config)
+    model.save_pretrained("tiny-llama4")
diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index daf58b97..953ad8bf 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -19,7 +19,7 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.50.0,!=4.46.*,!=4.47.*,!=4.48.0
+    transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
     datasets>=2.16.0,<=3.4.1
     accelerate>=0.34.0,<=1.5.2
     peft>=0.14.0,<=0.15.0
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index d5114eec..c5133539 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -466,6 +466,73 @@ class Gemma3Plugin(BasePlugin):
         return mm_inputs
 
 
+@dataclass
+class Llama4Plugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
+                num_patches_per_chunk = int(
+                    (image_height // processor.patch_size)
+                    * (image_width // processor.patch_size)
+                    // processor.downsample_ratio
+                )
+                aspect_ratios = mm_inputs.pop("aspect_ratios")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            placeholder_count = content.count(IMAGE_PLACEHOLDER)
+            if self.expand_mm_tokens:
+                prompt_splits = content.split(IMAGE_PLACEHOLDER)
+                new_content = []
+                for local_image_index, split_part in enumerate(prompt_splits):
+                    new_content.append(split_part)
+                    if local_image_index < placeholder_count:
+                        tokens_for_this_image = processor._prompt_split_image(
+                            aspect_ratios[num_image_tokens], num_patches_per_chunk
+                        )
+                        num_image_tokens += 1
+                        new_content.append(tokens_for_this_image)
+
+                content = "".join(new_content)
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("aspect_ratios", None)
+        return mm_inputs
+
+
 @dataclass
 class LlavaPlugin(BasePlugin):
     @override
@@ -1485,6 +1552,7 @@ class VideoLlavaPlugin(BasePlugin):
 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
+    "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index a29bf959..8732685e 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -968,6 +968,26 @@ register_template(
 )
 
 
+register_template(
+    name="llama4",
+    format_user=StringFormatter(
+        slots=["<|header_start|>user<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"]
+    ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot|>"]),
+    format_system=StringFormatter(slots=["<|header_start|>system<|header_end|>\n\n{{content}}<|eot|>"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot|>"], tool_format="llama3"),
+    format_observation=StringFormatter(
+        slots=[
+            "<|header_start|>ipython<|header_end|>\n\n{{content}}<|eot|><|header_start|>assistant<|header_end|>\n\n"
+        ]
+    ),
+    format_tools=ToolFormatter(tool_format="llama3"),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+    stop_words=["<|eot|>", "<|eom|>"],
+    mm_plugin=get_mm_plugin(name="llama4", image_token="<|image|>"),
+)
+
+
 # copied from llama3 template
 register_template(
     name="mllama",
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 702e9bc6..a1829145 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -1111,6 +1111,30 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Llama-4-Scout-17B-16E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E",
+        },
+        "Llama-4-Scout-17B-16E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Scout-17B-16E-Instruct",
+        },
+        "Llama-4-Maverick-17B-128E": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E",
+        },
+        "Llama-4-Maverick-17B-128E-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-4-Maverick-17B-128E-Instruct",
+        },
+    },
+    template="llama4",
+    multimodal=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 22fa490a..7880aa39 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -89,7 +89,7 @@ def check_version(requirement: str, mandatory: bool = False) -> None:
 
 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.41.2,<=4.50.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("transformers>=4.41.2,<=4.51.0,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
     check_version("datasets>=2.16.0,<=3.4.1")
     check_version("accelerate>=0.34.0,<=1.5.2")
     check_version("peft>=0.14.0,<=0.15.0")
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 051f6b04..28e2a795 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 
     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: torch.nn.Module = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__
         has_grad = False
 
         if any(param.requires_grad for param in module.parameters()):
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index c69bc690..75b71976 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -203,6 +203,12 @@ _register_composite_model(
 )
 
 
+_register_composite_model(
+    model_type="llama4",
+    vision_model_keys=["vision_model"],
+)
+
+
 _register_composite_model(
     model_type="llava",
 )
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index b285a003..14497459 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -243,7 +243,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         for idx in range(0, self.config.batch_size, self.config.mini_batch_size):
             mini_batch = {
                 "input_ids": batch["input_ids"][idx : idx + self.config.mini_batch_size],
-                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size]
+                "attention_mask": batch["attention_mask"][idx : idx + self.config.mini_batch_size],
             }
             mini_batch_queries, mini_batch_responses = self.get_inputs(mini_batch)
             mini_batch_rewards = self.get_rewards(mini_batch_queries, mini_batch_responses)
diff --git a/tests/version.txt b/tests/version.txt
index 1f4eec6c..399e7891 100644
--- a/tests/version.txt
+++ b/tests/version.txt
@@ -1 +1,2 @@
-0.9.3.100
+# change if test fails
+0.9.3.101