Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[data] Fix Qwen3VL plugin (#9297)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Yaowei Zheng <hiyouga@buaa.edu.cn>
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
parent 9c0d033a15
commit 129e918106
--- a/scripts/vllm_infer.py
+++ b/scripts/vllm_infer.py
@@ -16,6 +16,7 @@ import gc
 import json
 from typing import Optional
 
+import av
 import fire
 from tqdm import tqdm
 from transformers import Seq2SeqTrainingArguments
@@ -33,6 +34,14 @@ if is_vllm_available():
     from vllm.lora.request import LoRARequest
 
 
+def _need_video_kwargs(template):
+    NEEDED_TEMPLATE = ["qwen3_vl", "glm4v"]
+    if any(t in template for t in NEEDED_TEMPLATE):
+        return True
+
+    return False
+
+
 def vllm_infer(
     model_name_or_path: str,
     adapter_name_or_path: str = None,
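The helper above gates the extra video metadata handling on the template family by substring match. A quick sketch of its behavior (plain Python, nothing assumed beyond the function itself; the middle template name is hypothetical):

    assert _need_video_kwargs("qwen3_vl") is True
    assert _need_video_kwargs("qwen3_vl_instruct") is True  # hypothetical name; matches by substring
    assert _need_video_kwargs("qwen2_vl") is False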
@@ -132,6 +141,7 @@ def vllm_infer(
 
     # Store all results in these lists
     all_prompts, all_preds, all_labels = [], [], []
+    need_video_kwargs = _need_video_kwargs(template)
 
     # Add batch process to avoid the issue of too many files opened
     for i in tqdm(range(0, len(train_dataset), batch_size), desc="Processing batched inference"):
@@ -147,6 +157,7 @@ def vllm_infer(
                     )["images"]
                 }
             elif batch["videos"][j] is not None:
+                video_metadata, video_metadata_kwargs = None, None
                 video = batch["videos"][j]
                 multi_modal_data = {
                     "video": template_obj.mm_plugin._regularize_videos(
@@ -157,6 +168,25 @@ def vllm_infer(
                         video_maxlen=video_maxlen,
                     )["videos"]
                 }
+                if need_video_kwargs:
+                    container = av.open(video[0], "r")
+                    video_stream = next(stream for stream in container.streams if stream.type == "video")
+                    sampling_indices = template_obj.mm_plugin._get_video_sample_indices(
+                        video_stream, video_fps, video_maxlen
+                    )
+                    total_frames = video_stream.frames
+                    video_metadata_kwargs = {
+                        "fps": getattr(tokenizer_module["processor"], "video_fps", 24.0),
+                        "do_sample_frames": False,
+                        "total_num_frames": total_frames,
+                    }
+                    video_metadata = dict(
+                        fps=video_fps,
+                        frames_indices=sampling_indices,
+                        total_num_frames=total_frames,
+                        video_backend="opencv",
+                    )
+                    multi_modal_data["video"] = (multi_modal_data["video"], video_metadata)
             elif batch["audios"][j] is not None:
                 audio = batch["audios"][j]
                 audio_data = template_obj.mm_plugin._regularize_audios(
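The added block recomputes the same frame sampling the plugin performs in `_regularize_videos`, then hands vLLM both views of it: `video_metadata_kwargs` (forwarded as `mm_processor_kwargs`) tells the processor the frames are already sampled (`do_sample_frames: False`), while `video_metadata` travels with the frames themselves. A sketch of the two objects for a hypothetical 240-frame clip sampled down to 16 frames (all values illustrative):

    # video_metadata_kwargs -> forwarded as mm_processor_kwargs
    #   {"fps": 24.0, "do_sample_frames": False, "total_num_frames": 240}
    # video_metadata -> attached to the frames as (frames, video_metadata)
    #   dict(fps=2.0, frames_indices=array([0, 16, ..., 239]), total_num_frames=240, video_backend="opencv")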
@@ -167,7 +197,11 @@ def vllm_infer(
             else:
                 multi_modal_data = None
 
-            vllm_inputs.append({"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data})
+            vllm_input_data = {"prompt_token_ids": batch["input_ids"][j], "multi_modal_data": multi_modal_data}
+            if "video_metadata_kwargs" in locals() and video_metadata_kwargs is not None:
+                vllm_input_data["mm_processor_kwargs"] = video_metadata_kwargs
+
+            vllm_inputs.append(vllm_input_data)
             prompts.append(tokenizer.decode(batch["input_ids"][j], skip_special_tokens=skip_special_tokens))
             labels.append(
                 tokenizer.decode(
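Putting the branches together, one `vllm_inputs` entry for a video sample now has this shape (placeholder values; the `locals()` guard simply omits `mm_processor_kwargs` whenever no video metadata was prepared):

    # {
    #     "prompt_token_ids": [151644, 872, ...],
    #     "multi_modal_data": {"video": (frames, video_metadata)},
    #     "mm_processor_kwargs": {"fps": 24.0, "do_sample_frames": False, "total_num_frames": 240},
    # }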
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -31,7 +31,7 @@ from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
     get_cross_attention_token_mask,
 )
-from typing_extensions import override
+from typing_extensions import NotRequired, override
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.packages import (
@@ -77,6 +77,18 @@ if TYPE_CHECKING:
     VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
     AudioInput = Union[str, BinaryIO, NDArray]
 
+    class RegularizedImageOutput(TypedDict):
+        images: list[ImageObject]
+
+    class RegularizedVideoOutput(TypedDict):
+        videos: list[list[ImageObject]]
+        durations: list[float]
+        fps_per_video: NotRequired[list[float]]
+
+    class RegularizedAudioOutput(TypedDict):
+        audios: list[NDArray]
+        sampling_rates: list[float]
+
 class MMProcessor(ProcessorMixin):
     patch_size: int
     image_seq_length: int
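These TypedDicts give names to the previously ad-hoc dict return types of the regularize helpers. A self-contained sketch of how a `RegularizedVideoOutput` is built and consumed (re-declared here with string placeholders standing in for `ImageObject` frames):

    from typing import TypedDict

    from typing_extensions import NotRequired

    class RegularizedVideoOutput(TypedDict):
        videos: list[list[str]]  # placeholder for list[list[ImageObject]]
        durations: list[float]
        fps_per_video: NotRequired[list[float]]  # only the Qwen2VL-style override fills this

    out: RegularizedVideoOutput = {
        "videos": [["frame0", "frame1", "frame2", "frame3"]],
        "durations": [2.0],  # seconds of source video, not a frame count
    }
    print(len(out["videos"][0]) / out["durations"][0])  # 2.0 sampled frames per second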
@@ -244,7 +256,7 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
 
-    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> "RegularizedImageOutput":
         r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
@@ -265,9 +277,10 @@ class MMPluginMixin:
 
         return {"images": results}
 
-    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
         r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
+        durations = []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -275,6 +288,7 @@ class MMPluginMixin:
                 if not is_valid_image(frame) and not isinstance(frame, dict) and not os.path.exists(frame):
                     raise ValueError("Invalid image found in video frames.")
                 frames = video
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -284,14 +298,19 @@ class MMPluginMixin:
                     if frame_idx in sample_indices:
                         frames.append(frame.to_image())
 
+                if video_stream.duration is None:
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
+                else:
+                    durations.append(float(video_stream.duration * video_stream.time_base))
+
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results}
+        return {"videos": results, "durations": durations}
 
     def _regularize_audios(
         self, audios: list["AudioInput"], sampling_rate: float, **kwargs
-    ) -> dict[str, Union[list["NDArray"], list[float]]]:
+    ) -> "RegularizedAudioOutput":
         r"""Regularizes audios to avoid error. Including reading and resampling."""
         results, sampling_rates = [], []
         for audio in audios:
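The `float(video_stream.duration * video_stream.time_base)` expression converts PyAV's tick-based duration into seconds: `stream.duration` is counted in units of `stream.time_base`. A self-contained worked example with hypothetical values:

    from fractions import Fraction

    duration_ticks = 240240         # hypothetical stream.duration, in time_base ticks
    time_base = Fraction(1, 30000)  # hypothetical stream.time_base
    print(float(duration_ticks * time_base))  # 8.008 -> the clip is ~8 seconds long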
@@ -1418,10 +1437,8 @@ class Qwen2VLPlugin(BasePlugin):
         return image
 
     @override
-    def _regularize_videos(
-        self, videos: list["VideoInput"], **kwargs
-    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
-        results, fps_per_video = [], []
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
+        results, fps_per_video, durations = [], [], []
         for video in videos:
             frames: list[ImageObject] = []
             if _check_video_is_nested_images(video):
@@ -1431,6 +1448,7 @@ class Qwen2VLPlugin(BasePlugin):
 
                 frames = video
                 fps_per_video.append(kwargs.get("video_fps", 2.0))
+                durations.append(len(frames) / kwargs.get("video_fps", 2.0))
             else:
                 container = av.open(video, "r")
                 video_stream = next(stream for stream in container.streams if stream.type == "video")
@@ -1442,8 +1460,10 @@ class Qwen2VLPlugin(BasePlugin):
 
                 if video_stream.duration is None:
                     fps_per_video.append(kwargs.get("video_fps", 2.0))
+                    durations.append(len(frames) / kwargs.get("video_fps", 2.0))
                 else:
                     fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))
+                    durations.append(float(video_stream.duration * video_stream.time_base))
 
             if len(frames) % 2 != 0:
                 frames.append(frames[-1])
@@ -1451,7 +1471,7 @@ class Qwen2VLPlugin(BasePlugin):
             frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
 
-        return {"videos": results, "fps_per_video": fps_per_video}
+        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations}
 
     @override
     def _get_mm_inputs(
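`fps_per_video` and the new `durations` are two views of the same sampling: the effective fps is the number of sampled frames divided by the real clip length. A small numeric check (hypothetical clip: 8 seconds, 16 sampled frames):

    sample_indices = list(range(16))  # 16 sampled frames (hypothetical)
    duration = 8.0                    # seconds, from stream.duration * stream.time_base
    print(len(sample_indices) / duration)  # 2.0 -> what fps_per_video records
    print(duration)                        # 8.0 -> what durations now records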
@@ -1565,8 +1585,8 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_metadata = [
-                {"fps": getattr(processor, "video_fps", 24.0), "duration": len(video), "total_num_frames": len(video)}
-                for video in videos["videos"]
+                {"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video)}
+                for video, duration in zip(videos["videos"], videos["durations"])
             ]
             mm_inputs.update(
                 video_processor(videos=videos["videos"], video_metadata=video_metadata, return_metadata=True)
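This is the core of the Qwen3VL fix: the old code passed `len(video)`, the number of sampled frames, as `duration`, so downstream timestamp math treated a 16-frame clip as 16 seconds long regardless of its real length. With `durations` threaded through from `_regularize_videos`, the metadata reflects wall-clock seconds. A before/after sketch for a hypothetical 8-second clip sampled to 16 frames:

    video = [f"frame{i}" for i in range(16)]  # placeholder frames
    duration = 8.0                            # real length in seconds

    old_meta = {"fps": 24.0, "duration": len(video), "total_num_frames": len(video)}  # duration == 16, wrong
    new_meta = {"fps": 24.0, "duration": duration, "total_num_frames": len(video)}    # duration == 8.0
    print(old_meta["duration"], new_meta["duration"])  # 16 8.0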
@@ -1622,27 +1642,27 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 num_image_tokens += 1
 
             while VIDEO_PLACEHOLDER in content:
-                metadata = video_metadata[idx]
-                timestamps = processor._calculate_timestamps(
-                    metadata.frames_indices,
-                    metadata.fps,
-                    video_processor.merge_size,
-                )
-                video_structure = ""
-                for frame_index in range(num_frames):
-                    video_seqlen = (
-                        video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
-                        if self.expand_mm_tokens
-                        else 1
-                    )
-                    timestamp_sec = timestamps[frame_index]
-                    frame_structure = (
-                        f"<{timestamp_sec:.1f} seconds>"
-                        f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
-                    )
-                    video_structure += frame_structure
-
-                if not self.expand_mm_tokens:
-                    video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
+                if self.expand_mm_tokens:
+                    metadata = video_metadata[idx]
+                    timestamps = processor._calculate_timestamps(
+                        metadata.frames_indices,
+                        metadata.fps,
+                        video_processor.merge_size,
+                    )
+                    video_structure = ""
+                    for frame_index in range(num_frames):
+                        video_seqlen = (
+                            video_grid_thw[num_video_tokens][1:].prod() // video_merge_length
+                            if self.expand_mm_tokens
+                            else 1
+                        )
+                        timestamp_sec = timestamps[frame_index]
+                        frame_structure = (
+                            f"<{timestamp_sec:.1f} seconds>"
+                            f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
+                        )
+                        video_structure += frame_structure
+                else:
+                    video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
 
                 content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
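For reference, the expanded structure prefixes each frame's vision span with its timestamp tag; a sketch with illustrative token strings (the actual bos/eos/video tokens are template attributes):

    # timestamps = [0.0, 0.5], video_seqlen = 4 ->
    # "<0.0 seconds><|vision_bos|><|video|><|video|><|video|><|video|><|vision_eos|>"
    # "<0.5 seconds><|vision_bos|><|video|><|video|><|video|><|video|><|vision_eos|>"
    # With expand_mm_tokens False, the placeholder collapses to a single
    # "<|vision_bos|><|video|><|vision_eos|>" span, and the timestamp math
    # (which needs video_metadata) is now skipped entirely.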
@@ -1684,7 +1704,8 @@ class GLM4VPlugin(Qwen2VLPlugin):
             )
             # prepare video metadata
             video_metadata = [
-                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+                {"fps": 2, "duration": duration, "total_frames": len(video)}
+                for video, duration in zip(video_data["videos"], video_data["durations"])
             ]
             mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))