[model] add qwen2.5 vl models (#6779)

This commit is contained in:
hoshi-hiyouga
2025-01-31 03:00:29 +08:00
committed by GitHub
parent 15357cdad9
commit 999c7c8fe0
8 changed files with 77 additions and 30 deletions

View File

@@ -135,12 +135,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
features: Dict[str, "torch.Tensor"] = super().__call__(features)
if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
input_ids=features["input_ids"],
image_grid_thw=mm_inputs.get("image_grid_thw", None),
video_grid_thw=mm_inputs.get("video_grid_thw", None),
attention_mask=features["attention_mask"],
)
rope_index_kwargs = {
"input_ids": features["input_ids"],
"image_grid_thw": mm_inputs.get("image_grid_thw"),
"video_grid_thw": mm_inputs.get("video_grid_thw"),
"attention_mask": features["attention_mask"],
}
if "second_per_grid_ts" in mm_inputs:
rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")

View File

@@ -178,16 +178,16 @@ class BasePlugin:
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
image_resolution=getattr(processor, "image_resolution", 768 * 768),
)
input_dict["images"] = images
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
image_resolution=getattr(processor, "video_resolution", 256 * 256),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
input_dict["videos"] = videos
@@ -501,7 +501,7 @@ class MiniCPMVPlugin(BasePlugin):
if len(images) != 0:
images = self._regularize_images(
images,
image_resolution=getattr(processor, "image_resolution", 512 * 512),
image_resolution=getattr(processor, "image_resolution", 768 * 768),
)
if "valid_image_nums_ls" in kwargs:
valid_image_nums_ls = kwargs["valid_image_nums_ls"]
@@ -521,9 +521,9 @@ class MiniCPMVPlugin(BasePlugin):
if len(videos) != 0:
videos = self._regularize_videos(
videos,
image_resolution=getattr(processor, "video_resolution", 128 * 128),
image_resolution=getattr(processor, "video_resolution", 256 * 256),
video_fps=getattr(processor, "video_fps", 2.0),
video_maxlen=getattr(processor, "video_maxlen", 64),
video_maxlen=getattr(processor, "video_maxlen", 128),
)
video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
mm_inputs.update(video_inputs)
@@ -610,7 +610,7 @@ class MllamaPlugin(BasePlugin):
"""
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
imglens: List[int] = kwargs["imglens"]
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768))
batch_images = []
for image_length in imglens:
batch_images.append(images[:image_length])
@@ -875,7 +875,15 @@ class Qwen2vlPlugin(BasePlugin):
processor: Optional["ProcessorMixin"],
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
self._validate_input(images, videos)
return self._get_mm_inputs(images, videos, processor)
mm_inputs = self._get_mm_inputs(images, videos, processor)
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
video_fps = getattr(processor, "video_fps", 2.0)
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / video_fps] * len(
mm_inputs["video_grid_thw"]
)
return mm_inputs
class VideoLlavaPlugin(BasePlugin):