mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
[data] fix qwen_2_5_vl video processing (#6868)
* fix qwen_2_5_vl video processing * Update mm_plugin.py * Update mm_plugin.py --------- Co-authored-by: hoshi-hiyouga <hiyouga@buaa.edu.cn> Former-commit-id: 9153a7bd832cdae84b63a4d7d1f2b12239e84b61
This commit is contained in:
parent
703bb9cc18
commit
188f22d8a7
@ -230,9 +230,9 @@ class BasePlugin:
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
if "videos" in inspect.signature(video_processor.preprocess).parameters: # qwen2vl processor
|
||||
if "videos" in inspect.signature(video_processor.preprocess).parameters: # for qwen2_vl and video_llava
|
||||
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
|
||||
else:
|
||||
else: # for llava_next_video
|
||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||
|
||||
if len(audios) != 0:
|
||||
@ -1017,8 +1017,8 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
return image
|
||||
|
||||
@override
|
||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
|
||||
results = []
|
||||
def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> Tuple[List[List["ImageObject"]], List[float]]:
|
||||
results, fps_per_video = [], []
|
||||
for video in videos:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
@ -1034,8 +1034,41 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)
|
||||
results.append(frames)
|
||||
if video_stream.duration is None:
|
||||
fps_per_video.append(2.0)
|
||||
else:
|
||||
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
||||
|
||||
return results
|
||||
return results, fps_per_video
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: "ProcessorMixin",
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_resolution=getattr(processor, "image_resolution", 768 * 768),
|
||||
)
|
||||
mm_inputs.update(image_processor(images, return_tensors="pt"))
|
||||
|
||||
if len(videos) != 0:
|
||||
videos, fps_per_video = self._regularize_videos(
|
||||
videos,
|
||||
image_resolution=getattr(processor, "video_resolution", 256 * 256),
|
||||
video_fps=getattr(processor, "video_fps", 2.0),
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
|
||||
mm_inputs["fps_per_video"] = fps_per_video
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def process_messages(
|
||||
@ -1101,12 +1134,10 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
fps_per_video = mm_inputs.pop("fps_per_video", [])
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs:
|
||||
video_fps = getattr(processor, "video_fps", 2.0)
|
||||
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / video_fps] * len(
|
||||
mm_inputs["video_grid_thw"]
|
||||
)
|
||||
if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and fps_per_video:
|
||||
mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user