From 2c5f912e16baebc0f7caf4b31696f9620c293f5f Mon Sep 17 00:00:00 2001 From: Kingsley Date: Mon, 14 Oct 2024 16:55:59 +0800 Subject: [PATCH] remove bs condition Former-commit-id: 962b9730a7a2940a0d4e5c76d1fe41d0fef76547 --- src/llamafactory/data/mm_plugin.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index edca48a7..b9e7bc3b 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -477,7 +477,7 @@ class PixtralPlugin(BasePlugin): if image_input_sizes is None: raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) - + image_size = image_input_sizes[0][img_id] height, width = image_size num_height_tokens = height // patch_size @@ -500,7 +500,7 @@ class PixtralPlugin(BasePlugin): raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) return messages - + @override def get_mm_inputs( self, @@ -516,11 +516,6 @@ class PixtralPlugin(BasePlugin): if mm_inputs.get("image_sizes"): mm_inputs.pop("image_sizes") - if isinstance(mm_inputs.get("pixel_values"), list) and len(mm_inputs.get("pixel_values")[0]) >= 2: - raise ValueError("Now it only supports batchsize=1 on per gpu due to `List[tensor]` can not pack into BachEncoding") - - mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")[0][0].unsqueeze(0) - return mm_inputs class Qwen2vlPlugin(BasePlugin):