diff --git a/.gitignore b/.gitignore
index 1479cb24..630760ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,6 +159,9 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
 
+# vscode
+.vscode/
+
 # custom .gitignore
 ms_cache/
 hf_cache/
diff --git a/README.md b/README.md
index 8724520b..ee464cd9 100644
--- a/README.md
+++ b/README.md
@@ -186,6 +186,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
diff --git a/README_zh.md b/README_zh.md
index 88c3abb4..35b9522e 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -187,6 +187,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3-3.2](https://huggingface.co/meta-llama) | 1B/3B/8B/70B | llama3 |
+| [Llama 3.2 Vision](https://huggingface.co/meta-llama) | 11B/90B | mllama |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/8B/13B/34B/72B/110B | llava_next |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/34B | llava_next_video |
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index eeed9a29..f26d402a 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -164,7 +164,7 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )
 
-        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py
index 8fa6f0dd..4fc9e803 100644
--- a/src/llamafactory/data/collator.py
+++ b/src/llamafactory/data/collator.py
@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None
 
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_seqlens.append(len(feature["input_ids"]))
+            batch_input_ids.append(feature["input_ids"])
 
         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
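Reviewer note: the engine and collator now hand each plugin the raw token ids of every sample (`batch_ids`) instead of pre-computed lengths, so length-based plugins can still recover `seqlens` while the new mllama plugin can scan the ids for image tokens. A minimal sketch of the recovery, mirroring what `PaliGemmaPlugin` does further down; the helper name is mine and not part of the patch:

```python
from typing import Dict, List, Sequence


def get_mm_inputs_sketch(batch_ids: Sequence[List[int]]) -> Dict[str, List[int]]:
    """Hypothetical plugin hook: recover the old per-sample lengths from the raw token ids."""
    # Same information the removed ``seqlens`` argument carried.
    seqlens = [len(input_ids) for input_ids in batch_ids]
    # Plugins such as the new MllamaPlugin instead inspect the ids directly,
    # e.g. to locate <|image|> tokens when building the cross-attention mask.
    return {"seqlens": seqlens}


print(get_mm_inputs_sketch([[1, 2, 3], [4, 5]]))  # {'seqlens': [3, 2]}
```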
"{{image}}", 1) + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) - message["content"] = content.replace("{{image}}", self.image_token * image_seqlen) + message["content"] = content.replace("{{image}}", self.image_token) if len(images) != num_image_tokens: - raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens") + raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") return messages @@ -274,7 +281,7 @@ class LlavaPlugin(BasePlugin): videos: Sequence["VideoInput"], imglens: Sequence[int], vidlens: Sequence[int], - seqlens: Sequence[int], + batch_ids: Sequence[List[int]], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: self._validate_input(images, videos) @@ -296,23 +303,27 @@ class LlavaNextPlugin(BasePlugin): mm_inputs = self._get_mm_inputs(images, videos, processor) if "image_sizes" in mm_inputs: image_sizes = iter(mm_inputs["image_sizes"]) + if "pixel_values" in mm_inputs: height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) + for message in messages: content = message["content"] - while self.image_token in content: + while IMAGE_PLACEHOLDER in content: image_size = next(image_sizes) orig_height, orig_width = image_size image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) - if processor.vision_feature_select_strategy == "default": + if getattr(processor, "vision_feature_select_strategy") == "default": image_seqlen -= 1 + num_image_tokens += 1 - content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1) + content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1) message["content"] = content.replace("{{image}}", self.image_token) if len(images) != num_image_tokens: - raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens") + raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.") + return messages @override @@ -322,12 +333,11 @@ class LlavaNextPlugin(BasePlugin): videos: Sequence["VideoInput"], imglens: Sequence[int], vidlens: Sequence[int], - seqlens: Sequence[int], + batch_ids: Sequence[List[int]], processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: self._validate_input(images, videos) - res = self._get_mm_inputs(images, videos, processor) - return res + return self._get_mm_inputs(images, videos, processor) class LlavaNextVideoPlugin(BasePlugin): @@ -340,8 +350,7 @@ class LlavaNextVideoPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> List[Dict[str, str]]: self._validate_input(images, videos) - num_image_tokens = 0 - num_video_tokens = 0 + num_image_tokens, num_video_tokens = 0, 0 messages = deepcopy(messages) mm_inputs = self._get_mm_inputs(images, videos, processor) if "pixel_values" in mm_inputs: @@ -349,15 +358,15 @@ class LlavaNextVideoPlugin(BasePlugin): height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) for message in messages: content = message["content"] - - while self.image_token in content: + while IMAGE_PLACEHOLDER in content: image_size = next(image_sizes) orig_height, orig_width = image_size image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) - if processor.vision_feature_select_strategy == "default": + if getattr(processor, "vision_feature_select_strategy") == "default": image_seqlen -= 1 + num_image_tokens += 1 - content = 
@@ -223,7 +230,7 @@ class BasePlugin:
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         r"""
@@ -234,7 +241,7 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            seqlens: number of tokens in each sample, shape (batch_size,)
+            batch_ids: input ids of samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
@@ -258,12 +265,12 @@ class LlavaPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 
-            message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+            message["content"] = content.replace("{{image}}", self.image_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         return messages
@@ -274,7 +281,7 @@ class LlavaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -296,23 +303,27 @@ class LlavaNextPlugin(BasePlugin):
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         if "image_sizes" in mm_inputs:
             image_sizes = iter(mm_inputs["image_sizes"])
+
         if "pixel_values" in mm_inputs:
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+
         for message in messages:
             content = message["content"]
-            while self.image_token in content:
+            while IMAGE_PLACEHOLDER in content:
                 image_size = next(image_sizes)
                 orig_height, orig_width = image_size
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if processor.vision_feature_select_strategy == "default":
+                if getattr(processor, "vision_feature_select_strategy") == "default":
                     image_seqlen -= 1
+
                 num_image_tokens += 1
-                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 
             message["content"] = content.replace("{{image}}", self.image_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
         return messages
 
 
@@ -322,12 +333,11 @@ class LlavaNextPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        res = self._get_mm_inputs(images, videos, processor)
-        return res
+        return self._get_mm_inputs(images, videos, processor)
 
 
 class LlavaNextVideoPlugin(BasePlugin):
@@ -340,8 +350,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         if "pixel_values" in mm_inputs:
@@ -349,15 +358,15 @@ class LlavaNextVideoPlugin(BasePlugin):
             image_sizes = iter(mm_inputs["image_sizes"])
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
         for message in messages:
             content = message["content"]
-
-            while self.image_token in content:
+            while IMAGE_PLACEHOLDER in content:
                 image_size = next(image_sizes)
                 orig_height, orig_width = image_size
                 image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if processor.vision_feature_select_strategy == "default":
+                if getattr(processor, "vision_feature_select_strategy") == "default":
                     image_seqlen -= 1
+
                 num_image_tokens += 1
-                content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 
             message["content"] = content.replace("{{image}}", self.image_token)
 
@@ -367,19 +376,19 @@ class LlavaNextVideoPlugin(BasePlugin):
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
             video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
-
             for message in messages:
                 content = message["content"]
-                while self.video_token in content:
+                while VIDEO_PLACEHOLDER in content:
                     num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+
+                message["content"] = content.replace("{{video}}", self.video_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -390,7 +399,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -418,7 +427,7 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", "")
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -449,10 +458,11 @@ class PaliGemmaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
+        seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs
@@ -481,7 +491,7 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+                    raise ValueError("Cannot get image input sizes.")
 
                 image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
@@ -497,7 +507,7 @@ class PixtralPlugin(BasePlugin):
             message["content"] = content
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -508,7 +518,7 @@ class PixtralPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -592,10 +602,10 @@ class Qwen2vlPlugin(BasePlugin):
             message["content"] = content
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -606,7 +616,7 @@ class Qwen2vlPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
@@ -623,42 +633,45 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         num_frames = 0
-        exist_images = "pixel_values_images" in mm_inputs
-        exist_videos = "pixel_values_videos" in mm_inputs
-        if exist_videos or exist_images:
-            if exist_images:
+        has_images = "pixel_values_images" in mm_inputs
+        has_videos = "pixel_values_videos" in mm_inputs
+        if has_images or has_videos:
+            if has_images:
                 height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                 num_frames = 1
-            if exist_videos:
+
+            if has_videos:
                 pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                 height, width = get_image_size(pixel_values_video[0])
                 num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
             video_seqlen = image_seqlen * num_frames
-            if processor.vision_feature_select_strategy == "default":
+            if getattr(processor, "vision_feature_select_strategy") == "default":
                 image_seqlen -= 1
+
             for message in messages:
                 content = message["content"]
-                while self.image_token in content:
+                while IMAGE_PLACEHOLDER in content:
                     num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
-                while self.video_token in content:
-                    num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
 
-                content = content.replace("{{image}}", self.image_token * image_seqlen)
-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                while VIDEO_PLACEHOLDER in content:
+                    num_video_tokens += 1
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+
+                content = content.replace("{{image}}", self.image_token)
+                message["content"] = content.replace("{{video}}", self.video_token)
 
         if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {self.image_token} tokens")
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
 
         if len(videos) != num_video_tokens:
-            raise ValueError(f"The number of videos does not match the number of {self.video_token} tokens")
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
 
         return messages
 
@@ -669,13 +682,86 @@ class VideoLlavaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)
 
 
+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
+
+        Returns:
+            pixel_values: tensor with shape
+                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+                          For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        return image_processor([[image] for image in images], return_tensors="pt")
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        if len(images) != len(batch_ids):
+            raise ValueError("Mllama only supports one image per sample.")
+
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        num_tiles = mm_inputs.pop("num_tiles")
+        image_token_id = getattr(processor, "image_token_id")
+        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+        cross_attention_token_mask = [
+            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+        ]
+        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
+            cross_attention_token_mask,
+            num_tiles=num_tiles,
+            max_num_tiles=max_image_tiles,
+            length=max(len(input_ids) for input_ids in batch_ids),
+        )
+        return mm_inputs
+
+
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
+ """ + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512)) + return image_processor([[image] for image in images], return_tensors="pt") + + def get_mm_inputs( + self, + images: Sequence["ImageInput"], + videos: Sequence["VideoInput"], + imglens: Sequence[int], + vidlens: Sequence[int], + batch_ids: Sequence[List[int]], + processor: Optional["ProcessorMixin"], + ) -> Dict[str, Union[List[int], "torch.Tensor"]]: + self._validate_input(images, videos) + if len(images) != len(batch_ids): + raise ValueError("Mllama only supports one image per sample.") + + mm_inputs = self._get_mm_inputs(images, videos, processor) + num_tiles = mm_inputs.pop("num_tiles") + image_token_id = getattr(processor, "image_token_id") + max_image_tiles = getattr(processor.image_processor, "max_image_tiles") + cross_attention_token_mask = [ + get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids + ] + mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense( + cross_attention_token_mask, + num_tiles=num_tiles, + max_num_tiles=max_image_tiles, + length=max(len(input_ids) for input_ids in batch_ids), + ) + return mm_inputs + + PLUGINS = { "base": BasePlugin, "llava": LlavaPlugin, @@ -685,6 +771,7 @@ PLUGINS = { "pixtral": PixtralPlugin, "qwen2_vl": Qwen2vlPlugin, "video_llava": VideoLlavaPlugin, + "mllama": MllamaPlugin, } diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 27ffe9e8..6054b7a6 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -762,6 +762,33 @@ _register_template( ) +_register_template( + name="mllama", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), + format_observation=StringFormatter( + slots=[ + ( + "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|eot_id|>"], + replace_eos=True, + replace_jinja_template=False, + mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"), +) + + _register_template( name="llava", format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index f6738f81..5fdb33c2 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -855,6 +855,22 @@ register_model_group( ) +register_model_group( + models={ + "Llama-3.2-11B-Vision-Instruct": { + DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct", + }, + "Llama-3.2-90B-Vision-Instruct": { + DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct", + }, + }, + template="mllama", + vision=True, +) + + register_model_group( models={ "LLaVA-1.5-7B-Chat": { diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py index 98066714..44b9bb8a 100644 --- a/src/llamafactory/extras/packages.py +++ b/src/llamafactory/extras/packages.py @@ -75,8 
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index f6738f81..5fdb33c2 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -855,6 +855,22 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Llama-3.2-11B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
+        },
+        "Llama-3.2-90B-Vision-Instruct": {
+            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
+            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
+        },
+    },
+    template="mllama",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "LLaVA-1.5-7B-Chat": {
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 98066714..44b9bb8a 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -75,8 +75,8 @@ def is_starlette_available():
 
 
 @lru_cache
-def is_transformers_version_greater_than_4_43():
-    return _get_package_version("transformers") >= version.parse("4.43.0")
+def is_transformers_version_greater_than(content: str):
+    return _get_package_version("transformers") >= version.parse(content)
 
 
 @lru_cache
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 2f980142..dcfa117d 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -59,12 +59,12 @@ class ProcessorArguments:
     """
 
     image_resolution: int = field(
-        default=512,
-        metadata={"help": "Keeps the height or width of image below this resolution."},
+        default=512 * 512,
+        metadata={"help": "Keeps the number of pixels of image below this resolution."},
     )
     video_resolution: int = field(
-        default=128,
-        metadata={"help": "Keeps the height or width of video below this resolution."},
+        default=128 * 128,
+        metadata={"help": "Keeps the number of pixels of video below this resolution."},
     )
     video_fps: float = field(
         default=2.0,
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index 74adb015..96a7b40e 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -35,7 +35,7 @@ from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than
 
 
 if TYPE_CHECKING:
@@ -209,7 +209,7 @@ def llama_flash_attention_2_forward(
         if attention_mask is not None:
             attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1)
 
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         from transformers.modeling_flash_attention_utils import _flash_attention_forward
 
         attn_output: "torch.Tensor" = _flash_attention_forward(
diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py
index 52cf9eb3..5f4b747e 100644
--- a/src/llamafactory/model/model_utils/misc.py
+++ b/src/llamafactory/model/model_utils/misc.py
@@ -34,13 +34,15 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         forbidden_modules.add("output_layer")
     elif model_type == "internlm2":
         forbidden_modules.add("output")
-    elif model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
+    elif model_type in ["llava", "llava_next", "llava_next_video", "mllama", "paligemma", "video_llava"]:
         forbidden_modules.add("multi_modal_projector")
     elif model_type == "qwen2_vl":
         forbidden_modules.add("merger")
 
     if freeze_vision_tower:
-        if model_type == "qwen2_vl":
+        if model_type == "mllama":
+            forbidden_modules.add("vision_model")
+        elif model_type == "qwen2_vl":
             forbidden_modules.add("visual")
         else:
             forbidden_modules.add("vision_tower")
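Reviewer note: the one-off `is_transformers_version_greater_than_4_43` helper becomes a parameterized check; since `@lru_cache` memoizes per argument, each version string is compared only once. A condensed approximation of the refactored helper and of how it gates the mllama imports; the repo's version reads the version through its internal `_get_package_version`, while `transformers.__version__` stands in here:

```python
from functools import lru_cache

from packaging import version


@lru_cache
def is_transformers_version_greater_than(content: str) -> bool:
    """Approximation of the refactored helper in extras/packages.py."""
    import transformers  # local import keeps the sketch self-contained

    return version.parse(transformers.__version__) >= version.parse(content)


if is_transformers_version_greater_than("4.45.0"):
    # Guarded import: the mllama processing helpers only exist in transformers >= 4.45.0.
    from transformers.models.mllama.processing_mllama import get_cross_attention_token_mask  # noqa: F401
```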
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 899f346e..15921981 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -45,7 +45,7 @@ from transformers.utils.versions import require_version
 
 from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.packages import is_transformers_version_greater_than_4_43
+from ...extras.packages import is_transformers_version_greater_than
 
 
 if TYPE_CHECKING:
@@ -115,7 +115,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 
 def _patch_for_block_diag_attn(model_type: str) -> None:
     require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-    if is_transformers_version_greater_than_4_43():
+    if is_transformers_version_greater_than("4.43.0"):
         import transformers.modeling_flash_attention_utils
 
         transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index 1ac46e06..04f1eae8 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -26,7 +26,7 @@ from ...extras import logging
 
 
 if TYPE_CHECKING:
-    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel, ProcessorMixin
 
     from ...hparams import FinetuningArguments, ModelArguments
 
@@ -163,19 +163,21 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
     return image_seqlen
 
 
-def get_patch_size(config: "PretrainedConfig") -> int:
+def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Computes the patch size of the vit.
     """
-    patch_size = getattr(config.vision_config, "patch_size", -1)
+    patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
     return patch_size
 
 
-def get_vision_feature_select_strategy(config: "PretrainedConfig") -> int:
+def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
     r"""
     Get the vision_feature_select_strategy.
     """
-    vision_feature_select_strategy = getattr(config, "vision_feature_select_strategy", "default")
+    vision_feature_select_strategy = getattr(
+        config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
+    )
     return vision_feature_select_strategy
 
 
@@ -189,6 +191,8 @@ def patch_target_modules(
     if finetuning_args.freeze_vision_tower:
         if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "pixtral", "video_llava"]:
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
+        elif model_type == "mllama":
+            return "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
         else:
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 20046565..66823af7 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -66,11 +66,11 @@ def patch_processor(
     setattr(processor, "tokenizer", tokenizer)
     setattr(processor, "image_seqlen", get_image_seqlen(config))
     setattr(processor, "image_resolution", model_args.image_resolution)
-    setattr(processor, "patch_size", get_patch_size(config))
+    setattr(processor, "patch_size", get_patch_size(config, processor))
     setattr(processor, "video_resolution", model_args.video_resolution)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config))
+    setattr(processor, "vision_feature_select_strategy", get_vision_feature_select_strategy(config, processor))
 
 
 def patch_config(
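Reviewer note: in mllama the vision encoder lives under `vision_model` rather than `vision_tower`, hence the extra branch in `patch_target_modules` above. A small check of how the returned negative-lookahead pattern behaves with PEFT-style full-match semantics; the module paths are illustrative:

```python
import re

# Mirrors the pattern returned for mllama when the vision tower is frozen:
# target the LoRA modules everywhere except under the vision encoder.
target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*vision_model).*(?:{}).*".format("|".join(target_modules))

print(bool(re.fullmatch(pattern, "language_model.model.layers.0.self_attn.q_proj")))   # True  -> LoRA applied
print(bool(re.fullmatch(pattern, "vision_model.transformer.layers.3.self_attn.q_proj")))  # False -> stays frozen
```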