mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-28 11:14:18 +08:00
[data] Fix Qwen3VL plugin (#9297)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn> Co-authored-by: kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
parent
9c0d033a15
commit
129e918106
@ -16,6 +16,7 @@ import gc
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import fire
|
||||
from tqdm import tqdm
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
@ -33,6 +34,14 @@ if is_vllm_available():
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
def _need_video_kwargs(template):
|
||||
NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
|
||||
if any(t in template for t in NEEDED_TEMPLATE):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def vllm_infer(
|
||||
model_name_or_path: str,
|
||||
adapter_name_or_path: str = None,
|
||||
@ -132,6 +141,7 @@ def vllm_infer(
|
||||
|
||||
# Store all results in these lists
|
||||
all_prompts, all_preds, all_labels = [], [], []
|
||||
need_video_kwargs = _need_video_kwargs(template)
|
||||
|
||||
# Add batch process to avoid the issue of too many files opened
|
||||
for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
|
||||
@ -147,6 +157,7 @@ def vllm_infer(
|
||||
)["images"]
|
||||
}
|
||||
elif batch["videos"][j] is not None:
|
||||
video_metadata, video_metadata_kwargs = None, None
|
||||
video = batch["videos"][j]
|
||||
multi_modal_data = {
|
||||
"video": template_obj.mm_plugin._regularize_videos(
|
||||
@ -157,6 +168,25 @@ def vllm_infer(
|
||||
video_maxlen=video_maxlen,
|
||||
)["videos"]
|
||||
}
|
||||
if need_video_kwargs:
|
||||
container = av.open(video[0], "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
|
||||
video_stream, video_fps, video_maxlen
|
||||
)
|
||||
total_frames = video_stream.frames
|
||||
video_metadata_kwargs = {
|
||||
"fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
|
||||
"do_sample_frames": False,
|
||||
"total_num_frames": total_frames,
|
||||
}
|
||||
video_metadata = dict(
|
||||
fps=video_fps,
|
||||
frames_indices=sampling_indices,
|
||||
total_num_frames=total_frames,
|
||||
video_backend="opencv",
|
||||
)
|
||||
multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
|
||||
elif batch["audios"][j] is not None:
|
||||
audio = batch["audios"][j]
|
||||
audio_data = template_obj.mm_plugin._regularize_audios(
|
||||
@ -167,7 +197,11 @@ def vllm_infer(
|
||||
else:
|
||||
multi_modal_data = None
|
||||
|
||||
vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
|
||||
vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
|
||||
if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
|
||||
vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
|
||||
|
||||
vllm_inputs.append(vllm_input_data)
|
||||
prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
|
||||
labels.append(
|
||||
tokenizer.decode(
|
||||
|
||||
@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
|
||||
convert_sparse_cross_attention_mask_to_dense,
|
||||
get_cross_attention_token_mask,
|
||||
)
|
||||
from typing_extensions import override
|
||||
from typing_extensions import NotRequired, override
|
||||
|
||||
from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import (
|
||||
@ -77,6 +77,18 @@ if TYPE_CHECKING:
|
||||
VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
|
||||
AudioInput = Union[str, BinaryIO, NDArray]
|
||||
|
||||
class RegularizedImageOutput(TypedDict):
|
||||
images: list[ImageObject]
|
||||
|
||||
class RegularizedVideoOutput(TypedDict):
|
||||
videos: list[list[ImageObject]]
|
||||
durations: list[float]
|
||||
fps_per_video: NotRequired[list[float]]
|
||||
|
||||
class RegularizedAudioOutput(TypedDict):
|
||||
audios: list[NDArray]
|
||||
sampling_rates: list[float]
|
||||
|
||||
class MMProcessor(ProcessorMixin):
|
||||
patch_size: int
|
||||
image_seq_length: int
|
||||
@ -244,7 +256,7 @@ class MMPluginMixin:
|
||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||
|
||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
|
||||
def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
|
||||
r"""Regularize images to avoid error. Including reading and pre-processing."""
|
||||
results = []
|
||||
for image in images:
|
||||
@ -265,9 +277,10 @@ class MMPluginMixin:
|
||||
|
||||
return {"images": results}
|
||||
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||
r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
|
||||
results = []
|
||||
durations = []
|
||||
for video in videos:
|
||||
frames: list[ImageObject] = []
|
||||
if _check_video_is_nested_images(video):
|
||||
@ -275,6 +288,7 @@ class MMPluginMixin:
|
||||
if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
|
||||
raise ValueError("Invalid image found in video frames.")
|
||||
frames = video
|
||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||
else:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
@ -284,14 +298,19 @@ class MMPluginMixin:
|
||||
if frame_idx in sample_indices:
|
||||
frames.append(frame.to_image())
|
||||
|
||||
if video_stream.duration is None:
|
||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||
else:
|
||||
durations.append(float(video_stream.duration * video_stream.time_base))
|
||||
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
|
||||
return {"videos": results}
|
||||
return {"videos": results, "durations": durations}
|
||||
|
||||
def _regularize_audios(
|
||||
self, audios: list["AudioInput"], sampling_rate: float, **kwargs
|
||||
) -> dict[str, Union[list["NDArray"], list[float]]]:
|
||||
) -> "RegularizedAudioOutput":
|
||||
r"""Regularizes audios to avoid error. Including reading and resampling."""
|
||||
results, sampling_rates = [], []
|
||||
for audio in audios:
|
||||
@ -1418,10 +1437,8 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
return image
|
||||
|
||||
@override
|
||||
def _regularize_videos(
|
||||
self, videos: list["VideoInput"], **kwargs
|
||||
) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
|
||||
results, fps_per_video = [], []
|
||||
def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
|
||||
results, fps_per_video, durations = [], [], []
|
||||
for video in videos:
|
||||
frames: list[ImageObject] = []
|
||||
if _check_video_is_nested_images(video):
|
||||
@ -1431,6 +1448,7 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
frames = video
|
||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||
else:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
@ -1442,8 +1460,10 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
|
||||
if video_stream.duration is None:
|
||||
fps_per_video.append(kwargs.get("video_fps", 2.0))
|
||||
durations.append(len(frames) / kwargs.get("video_fps", 2.0))
|
||||
else:
|
||||
fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
|
||||
durations.append(float(video_stream.duration * video_stream.time_base))
|
||||
|
||||
if len(frames) % 2 != 0:
|
||||
frames.append(frames[-1])
|
||||
@ -1451,7 +1471,7 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
|
||||
return {"videos": results, "fps_per_video": fps_per_video}
|
||||
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
@ -1565,8 +1585,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
video_metadata = [
|
||||
{"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
|
||||
for video in videos["videos"]
|
||||
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
|
||||
for video, duration in zip(videos["videos"], videos["durations"])
|
||||
]
|
||||
mm_inputs.update(
|
||||
video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
|
||||
@ -1622,27 +1642,27 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
metadata = video_metadata[idx]
|
||||
timestamps = processor._calculate_timestamps(
|
||||
metadata.frames_indices,
|
||||
metadata.fps,
|
||||
video_processor.merge_size,
|
||||
)
|
||||
video_structure = ""
|
||||
for frame_index in range(num_frames):
|
||||
video_seqlen = (
|
||||
video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
|
||||
if self.expand_mm_tokens
|
||||
else 1
|
||||
if self.expand_mm_tokens:
|
||||
metadata = video_metadata[idx]
|
||||
timestamps = processor._calculate_timestamps(
|
||||
metadata.frames_indices,
|
||||
metadata.fps,
|
||||
video_processor.merge_size,
|
||||
)
|
||||
timestamp_sec = timestamps[frame_index]
|
||||
frame_structure = (
|
||||
f"<{timestamp_sec:.1f} seconds>"
|
||||
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
|
||||
)
|
||||
video_structure += frame_structure
|
||||
|
||||
if not self.expand_mm_tokens:
|
||||
video_structure = ""
|
||||
for frame_index in range(num_frames):
|
||||
video_seqlen = (
|
||||
video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
|
||||
if self.expand_mm_tokens
|
||||
else 1
|
||||
)
|
||||
timestamp_sec = timestamps[frame_index]
|
||||
frame_structure = (
|
||||
f"<{timestamp_sec:.1f} seconds>"
|
||||
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
|
||||
)
|
||||
video_structure += frame_structure
|
||||
else:
|
||||
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
|
||||
|
||||
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
|
||||
@ -1684,7 +1704,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
|
||||
)
|
||||
# prepare video metadata
|
||||
video_metadata = [
|
||||
{"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
|
||||
{"fps": 2, "duration": duration, "total_frames": len(video)}
|
||||
for video, duration in zip(video_data["videos"], video_data["durations"])
|
||||
]
|
||||
mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user