From ce0c73c032f1b7c499cc635b3a3193ef4977d4c0 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 29 Sep 2024 20:53:34 +0800 Subject: [PATCH 1/2] Update mm_plugin.py Former-commit-id: 0257a67cb266dcaee8bfb358d88ef2be2403a2f7 --- src/llamafactory/data/mm_plugin.py | 38 ++++++++++-------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index f38031ca..29d80b50 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -158,6 +158,8 @@ class BasePlugin: It holds num_patches == torch.prod(image_grid_thw) """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor) + res = super()._get_mm_inputs(images, [], processor) input_dict = {"images": None} # default key if len(images) != 0: images = self._regularize_images( @@ -174,10 +176,17 @@ class BasePlugin: video_maxlen=getattr(processor, "video_maxlen", 64), ) input_dict["videos"] = videos - if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None: - return image_processor(**input_dict, return_tensors="pt") - else: - return {} + + mm_inputs = {} + if image_processor != video_processor: + if input_dict.get("images") is not None: + mm_inputs.update(image_processor(input_dict["images"], return_tensors="pt")) + if input_dict.get("videos") is not None: + mm_inputs.update(video_processor(input_dict["videos"], return_tensors="pt")) + elif input_dict.get("images") is not None or input_dict.get("videos") is not None: # same processor (qwen2-vl) + mm_inputs.update(image_processor(**input_dict, return_tensors="pt")) + + return mm_inputs def process_messages( self, @@ -365,27 +374,6 @@ class LlavaNextVideoPlugin(BasePlugin): return messages - @override - def _get_mm_inputs( - self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - self._validate_input(images, videos) - video_processor = getattr(processor, "video_processor") - res = super()._get_mm_inputs(images, [], processor) - if len(videos) != 0: - videos = self._regularize_videos( - videos, - image_resolution=getattr(processor, "image_resolution"), - video_fps=getattr(processor, "video_fps"), - video_maxlen=getattr(processor, "video_maxlen"), - ) - video_res = video_processor(videos, return_tensors="pt") - res.update(video_res) - return res - @override def get_mm_inputs( self, From ec793d16de4b2ca7077d8adbcbdeb09bbce2e7ae Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Sun, 29 Sep 2024 20:54:04 +0800 Subject: [PATCH 2/2] Update mm_plugin.py Former-commit-id: ffaea305fc405c9892aa0c9712d98185d9241e69 --- src/llamafactory/data/mm_plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 29d80b50..3684495b 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -159,7 +159,6 @@ class BasePlugin: """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor) - res = super()._get_mm_inputs(images, [], processor) input_dict = {"images": None} # default key if len(images) != 0: images = self._regularize_images(