remove some unnecessary if conditions

Former-commit-id: de06e2678e2168586614242f65939c5772e78774
Kingsley 2024-09-28 02:14:06 +08:00
parent b76116bb6c
commit 66e473d519
3 changed files with 43 additions and 17 deletions


@@ -325,6 +325,14 @@ class PaliGemmaPlugin(BasePlugin):
         return mm_inputs

 class PixtralPlugin(BasePlugin):
+
+    # @override
+    # def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
+    #     image = super()._preprocess_image(image, **kwargs)
+    #     UP_SIZE = (512, 512)
+    #     image = image.resize(UP_SIZE, resample=Image.NEAREST)
+    #     return image
+
     @override
     def process_messages(
         self,
@@ -340,15 +348,22 @@ class PixtralPlugin(BasePlugin):
         self._validate_input(images, videos)
         num_image_tokens = 0
-        image_input_sizes = self._get_mm_inputs(images, videos, processor)["image_sizes"]
+        img_kwargs = self._get_mm_inputs(images, videos, processor)
+        image_input_sizes = None
+        if img_kwargs.get("pixel_values") is not None:
+            image_input_sizes = img_kwargs["image_sizes"]
+
         messages = deepcopy(messages)
-        print(image_input_sizes[0], messages)
         for message in messages:
             content = message["content"]
             img_id = 0

             while IMAGE_PLACEHOLDER in content:
-                # only support one image for one time?
-                image_size = image_input_sizes[0][0]
+                if image_input_sizes is None:
+                    raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+                image_size = image_input_sizes[0][img_id]
                 height, width = image_size
                 num_height_tokens = height // patch_size
                 num_width_tokens = width // patch_size
@@ -359,7 +374,7 @@ class PixtralPlugin(BasePlugin):
                 replace_tokens = [item for sublist in replace_tokens for item in sublist]
                 replace_tokens[-1] = image_end_token
                 replace_str = "".join(replace_tokens)
-                content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                 img_id += 1
                 num_image_tokens += 1
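
For reference, a minimal standalone sketch of the placeholder expansion that process_messages performs in the hunks above: each IMAGE_PLACEHOLDER becomes a patch grid of image tokens, with a break token closing every row and an end token closing the whole image. The literal "[IMG]", "[IMG_BREAK]" and "[IMG_END]" strings and the default patch_size of 16 are assumptions for illustration; in the plugin they come from the processor.

def expand_image_placeholder(height: int, width: int, patch_size: int = 16) -> str:
    # number of patch rows/columns covered by the image
    num_height_tokens = height // patch_size
    num_width_tokens = width // patch_size
    # one row of image tokens per patch row, each row terminated by a break token
    replace_tokens = [["[IMG]"] * num_width_tokens + ["[IMG_BREAK]"]] * num_height_tokens
    # flatten the rows and replace the final break token with the end token
    replace_tokens = [item for sublist in replace_tokens for item in sublist]
    replace_tokens[-1] = "[IMG_END]"
    return "".join(replace_tokens)

# a 512x512 image with 16x16 patches yields a 32x32 grid:
# 1024 "[IMG]" tokens, 31 "[IMG_BREAK]" tokens, 1 "[IMG_END]" token
expanded = expand_image_placeholder(512, 512)
print(expanded.count("[IMG]"), expanded.count("[IMG_BREAK]"), expanded.count("[IMG_END]"))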
@@ -383,7 +398,16 @@ class PixtralPlugin(BasePlugin):
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if mm_inputs.get("image_sizes"):
+            del mm_inputs["image_sizes"]
+
+        # TODO: fix this type error
+        # if isinstance(mm_inputs.get("pixel_values"), list):  # List[List[torch.Tensor]] -> [B C W H]
+        # recommended for batch == 1 on a single GPU, otherwise BatchEncoding raises an error
+        mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")[0][0].unsqueeze(0)
+        # mm_inputs["pixel_values"] = mm_inputs.get("pixel_values")
+        return mm_inputs

 class Qwen2vlPlugin(BasePlugin):
     @override
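
A side note on the pixel_values handling added in the hunk above: per the comments in the diff, the processor output is assumed to be a List[List[torch.Tensor]] (one list per sample, one tensor per image), so with batch == 1 and a single image, indexing [0][0] and unsqueezing yields a plain 4-D tensor that BatchEncoding can carry. A hedged illustration with a dummy tensor:

import torch

# stand-in for the assumed processor output: one sample holding one C x H x W image tensor
pixel_values = [[torch.randn(3, 512, 512)]]

# pick the single image of the single sample and add a batch dimension
batched = pixel_values[0][0].unsqueeze(0)
print(batched.shape)  # torch.Size([1, 3, 512, 512])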


@@ -917,16 +917,6 @@ register_model_group(
     template="mistral",
 )


-register_model_group(
-    models={
-        "Pixtral-12B-2409": {
-            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
-            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
-        }
-    },
-    template="mistral"
-)
-
 register_model_group(
     models={
@@ -1067,6 +1057,18 @@ register_model_group(
 )


+register_model_group(
+    models={
+        "Pixtral-12B-2409": {
+            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
+        }
+    },
+    template="mistral",
+    vision=True
+)
+
+
 register_model_group(
     models={
         "Qwen-1.8B": {