Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 14:22:51 +08:00
[data] fix qwen_2_5_vl video processing (#6868)
* fix qwen_2_5_vl video processing
* Update mm_plugin.py
* Update mm_plugin.py

Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn>
Former-commit-id: 9153a7bd832cdae84b63a4d7d1f2b12239e84b61
parent 703bb9cc18
commit 188f22d8a7
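Summary of the change: the previous Qwen2-VL plugin computed `second_per_grid_ts` from the single configured `video_fps`, implicitly assuming every video was sampled at that rate. After this fix, `_regularize_videos` also returns the actual sampling rate of each video (`fps_per_video`, falling back to 2.0 when the container reports no duration), and `process_messages` derives `second_per_grid_ts` from those per-video rates.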
@@ -230,9 +230,9 @@ class BasePlugin:
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # qwen2vl processor
+            if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                 mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
-            else:
+            else:  # for llava_next_video
                 mm_inputs.update(video_processor(videos, return_tensors="pt"))
 
         if len(audios) != 0:
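The branch above dispatches on the processor's `preprocess` signature rather than on model names. A minimal standalone sketch of that dispatch idea (the `DummyProcessor` class is hypothetical, for illustration only):

import inspect

def supports_videos_kwarg(video_processor) -> bool:
    # True when preprocess() declares a `videos` parameter,
    # as the Qwen2-VL and Video-LLaVA processors do.
    return "videos" in inspect.signature(video_processor.preprocess).parameters

class DummyProcessor:
    def preprocess(self, images=None, videos=None):
        ...

assert supports_videos_kwarg(DummyProcessor())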
@@ -1017,8 +1017,8 @@ class Qwen2vlPlugin(BasePlugin):
         return image
 
     @override
-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        results = []
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> Tuple[List[List["ImageObject"]], List[float]]:
+        results, fps_per_video = [], []
         for video in videos:
             container = av.open(video, "r")
             video_stream = next(stream for stream in container.streams if stream.type == "video")
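For context, PyAV exposes the stream metadata the new fps bookkeeping relies on. A minimal sketch, assuming a local file `sample.mp4` (placeholder name):

import av

container = av.open("sample.mp4", "r")
video_stream = next(stream for stream in container.streams if stream.type == "video")
# duration is reported in time_base ticks; seconds = duration * time_base
print(video_stream.duration, video_stream.time_base, video_stream.frames)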
@@ -1034,8 +1034,41 @@ class Qwen2vlPlugin(BasePlugin):
 
             frames = self._regularize_images(frames, **kwargs)
             results.append(frames)
+            if video_stream.duration is None:
+                fps_per_video.append(2.0)
+            else:
+                fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
 
-        return results
+        return results, fps_per_video
 
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        audios: Sequence["AudioInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 768 * 768),
+            )
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            videos, fps_per_video = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 256 * 256),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
+            mm_inputs["fps_per_video"] = fps_per_video
+
+        return mm_inputs
+
     @override
     def process_messages(
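The fps bookkeeping added above converts PyAV's tick-based duration into seconds and divides the number of sampled frames by it, falling back to 2.0 when the container reports no duration. A standalone sketch of that arithmetic (the function name `effective_fps` is ours, not the plugin's):

from fractions import Fraction
from typing import Optional

def effective_fps(num_sampled_frames: int, duration_ticks: Optional[int],
                  time_base: Fraction, fallback: float = 2.0) -> float:
    # PyAV reports duration in time_base ticks; seconds = duration * time_base.
    if duration_ticks is None:  # some containers omit duration metadata
        return fallback
    return num_sampled_frames / float(duration_ticks * time_base)

# e.g. 64 frames sampled from a 30-second stream (30000 ticks at 1/1000 s per tick):
# effective_fps(64, 30000, Fraction(1, 1000)) == 64 / 30.0, roughly 2.13 fps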
@@ -1101,12 +1134,10 @@ class Qwen2vlPlugin(BasePlugin):
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        fps_per_video = mm_inputs.pop("fps_per_video", [])
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
-            video_fps = getattr(processor, "video_fps", 2.0)
-            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / video_fps] * len(
-                mm_inputs["video_grid_thw"]
-            )
+        if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and fps_per_video:
+            mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
 
         return mm_inputs
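With per-video fps available, `second_per_grid_ts` becomes a per-video quantity: each temporal grid step covers `temporal_patch_size` sampled frames, so dividing by the sampling fps yields that step's duration in seconds. A minimal sketch of the conversion (the function name `seconds_per_grid_step` is ours):

from typing import List

def seconds_per_grid_step(temporal_patch_size: int, fps_per_video: List[float]) -> List[float]:
    # (frames per grid step) / (frames per second) = seconds per grid step
    return [temporal_patch_size / fps for fps in fps_per_video]

# e.g. temporal_patch_size=2 at 2.0 fps means each grid step spans 1.0 second
assert seconds_per_grid_step(2, [2.0, 4.0]) == [1.0, 0.5]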