diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 2b1d9fe5..68416fdf 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -168,7 +168,7 @@ class HuggingfaceEngine(BaseEngine):
         for key, value in mm_inputs.items():
             value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
             gen_kwargs[key] = value.to(model.device)
-
+
         return gen_kwargs, prompt_length
 
     @staticmethod
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 0e59ec0b..6716527c 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -325,6 +325,14 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs
 
 
 class PixtralPlugin(BasePlugin):
+    # @override
+    # def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+    #     image = super()._preprocess_image(image, **kwargs)
+    #     UP_SIZE = (512, 512)
+    #     image = image.resize(UP_SIZE, resample=Image.NEAREST)
+
+    #     return image
+
     @override
     def process_messages(
         self,
@@ -340,15 +348,22 @@ class PixtralPlugin(BasePlugin):
         self._validate_input(images, videos)
         num_image_tokens = 0
-        image_input_sizes = self._get_mm_inputs(images, videos, processor)["image_sizes"]
+        img_kwargs = self._get_mm_inputs(images, videos, processor)
+        image_input_sizes = None
+
+        if img_kwargs.get("pixel_values") is not None:
+            image_input_sizes = img_kwargs["image_sizes"]
+
         messages = deepcopy(messages)
-        print(image_input_sizes[0], messages)
         for message in messages:
             content = message["content"]
             img_id = 0
             while IMAGE_PLACEHOLDER in content:
-                # only support one image for one time?
-                image_size = image_input_sizes[0][0]
+                if image_input_sizes is None:
+                    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+                image_size = image_input_sizes[0][img_id]
                 height, width = image_size
                 num_height_tokens = height // patch_size
                 num_width_tokens = width // patch_size
@@ -359,7 +374,7 @@ class PixtralPlugin(BasePlugin):
                 replace_tokens = [item for sublist in replace_tokens for item in sublist]
                 replace_tokens[-1] = image_end_token
                 replace_str = "".join(replace_tokens)
-                content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                 img_id += 1
                 num_image_tokens += 1
@@ -383,7 +398,16 @@ class PixtralPlugin(BasePlugin):
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if mm_inputs.get("image_sizes"):
+            del mm_inputs["image_sizes"]
+        # TODO: fix this type error
+        # if isinstance(mm_inputs.get("pixel_values"), list):  # List[List[torch.Tensor]] -> [B, C, W, H]
+        # NOTE: only batch_size == 1 on a single GPU is supported for now; otherwise it raises a BatchEncoding error.
+        mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")[0][0].unsqueeze(0)
+        # mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")
+
+        return mm_inputs
 
 
 class Qwen2vlPlugin(BasePlugin):
     @override
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index ef075cf9..e3f6a99d 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -917,16 +917,6 @@ register_model_group(
     template="mistral",
 )
 
-register_model_group(
-    models={
-        "Pixtral-12B-2409": {
-            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
-        }
-    },
-    template="mistral"
-)
-
 
 register_model_group(
     models={
@@ -1067,6 +1057,18 @@ register_model_group(
 )
 
 
+register_model_group(
+    models={
+        "Pixtral-12B-2409": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="mistral",
+    vision=True,
+)
+
+
 register_model_group(
     models={
         "Qwen-1.8B": {