From 8f73c75c16aa1e3bca87e6e0ac068df7d5cbbe21 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Tue, 14 Jan 2025 16:47:58 +0800
Subject: [PATCH] [model] fix mllama any image (#6637)

* fix mllama any image

* reorder classes

Former-commit-id: 98189c8e4d70bf5f8ee83852a023ed27dfc96900
---
 src/llamafactory/data/mm_plugin.py | 156 +++++++++++++++--------------
 src/llamafactory/data/template.py  |  16 ++-
 2 files changed, 91 insertions(+), 81 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index e4447ee5..2739cacc 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -567,6 +567,85 @@ class MiniCPMVPlugin(BasePlugin):
         return mm_inputs
 
 
+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+        **kwargs,
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
+
+        Returns:
+            pixel_values: tensor with shape
+                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+                          For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        imglens: List[int] = kwargs["imglens"]
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        batch_images = []
+        for image_length in imglens:
+            batch_images.append(images[:image_length])
+            images = images[image_length:]
+
+        return image_processor(batch_images, return_tensors="pt")
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
+        num_tiles = mm_inputs.pop("num_tiles")
+        image_token_id = getattr(processor, "image_token_id")
+        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+        cross_attention_token_mask = [
+            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+        ]
+        mm_inputs["cross_attention_mask"] = torch.from_numpy(
+            convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=max_image_tiles,
+                length=max(len(input_ids) for input_ids in batch_ids),
+            )
+        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        return mm_inputs
+
+
 class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -868,92 +947,17 @@ class VideoLlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)
 
 
-class MllamaPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens = 0
-        messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
-            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        return messages
-
-    @override
-    def _get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: "ProcessorMixin",
-    ) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
-
-        Returns:
-            pixel_values: tensor with shape
-                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
-                          For example, (2, 1, 4, 3, 560, 560).
-            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
-            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
-            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
-        """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
-        return image_processor([[image] for image in images], return_tensors="pt")
-
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        if len(images) != len(batch_ids):
-            raise ValueError("Mllama only supports one image per sample.")
-
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = torch.from_numpy(
-            convert_sparse_cross_attention_mask_to_dense(
-                cross_attention_token_mask,
-                num_tiles=num_tiles,
-                max_num_tiles=max_image_tiles,
-                length=max(len(input_ids) for input_ids in batch_ids),
-            )
-        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
-        return mm_inputs
-
-
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
     "minicpm_v": MiniCPMVPlugin,
+    "mllama": MllamaPlugin,
     "paligemma": PaliGemmaPlugin,
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
-    "mllama": MllamaPlugin,
 }
 
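The heart of the mm_plugin.py change is the imglens-based regrouping in the new _get_mm_inputs: the flat image list is split into one sublist per sample, so a sample may carry any number of images, whereas the removed code wrapped every image as its own single-image group ([[image] for image in images]) and refused batches where the number of images differed from the number of samples. A minimal sketch of that grouping, with placeholder strings standing in for real ImageInput objects and group_images_by_sample as a hypothetical helper name:

    from typing import List, Sequence

    def group_images_by_sample(images: Sequence[str], imglens: Sequence[int]) -> List[List[str]]:
        # Mirrors the loop added to MllamaPlugin._get_mm_inputs: consume
        # imglens[i] images for sample i, preserving their order.
        batch_images: List[List[str]] = []
        remaining = list(images)
        for image_length in imglens:
            batch_images.append(remaining[:image_length])
            remaining = remaining[image_length:]
        return batch_images

    # Two samples: the first holds two images, the second holds one.
    assert group_images_by_sample(["img0", "img1", "img2"], [2, 1]) == [["img0", "img1"], ["img2"]]

The grouped list is exactly the List[List[ImageInput]] that Mllama's image processor expects, and it yields a per-sample num_tiles that get_mm_inputs consumes when building the cross-attention mask.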
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index c9340afc..65c17956 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -759,8 +759,9 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -786,8 +787,9 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -838,8 +840,9 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -1130,15 +1133,18 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
-                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         ]
     ),
+    format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
         "You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
         "involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
         "you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
         "After completing your thoughts, you then provide a detailed explanation of the solution process "
         "in your response."
     ),
-    stop_words=["<|eot_id|>"],
+    stop_words=["<|eot_id|>", "<|eom_id|>"],
 )
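On the template.py side, assistant and function turns now emit {{content}}<|eot_id|> as a single slot, the Skywork-o1 template gains llama3 tool support with the ipython observation header, and <|eom_id|> joins the stop words. Back in mm_plugin.py, the mask densification in get_mm_inputs can be previewed in isolation. A sketch using the same transformers helpers the plugin imports; the token ids and tile counts below are made-up values, and the image token id is hard-coded only for illustration (the real code reads it from the processor):

    import torch
    from transformers.models.mllama.processing_mllama import (
        convert_sparse_cross_attention_mask_to_dense,
        get_cross_attention_token_mask,
    )

    IMAGE_TOKEN_ID = 128256  # assumed id of Mllama's <|image|> token
    batch_ids = [[128000, IMAGE_TOKEN_ID, 9906, 1917]]  # one sample, one image
    num_tiles = [[4]]  # that image was split into four tiles

    # Sparse per-sample spans of the tokens that attend to each image.
    sparse = [get_cross_attention_token_mask(ids, IMAGE_TOKEN_ID) for ids in batch_ids]
    dense = torch.from_numpy(
        convert_sparse_cross_attention_mask_to_dense(
            sparse,
            num_tiles=num_tiles,
            max_num_tiles=4,
            length=max(len(ids) for ids in batch_ids),
        )
    )
    print(dense.shape)  # (batch_size, length, max_num_images, max_num_tiles) -> torch.Size([1, 4, 1, 4])

Because num_tiles now arrives as one inner list per sample, mirroring imglens, the dense mask stays aligned even when samples carry different numbers of images.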