mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
fix qwen2vl plugin (#6855)
Former-commit-id: 40048ab77a8b25a91a844800f0f1e880b84548cd
This commit is contained in:
parent
f70208e1c0
commit
28037c7834
@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -117,16 +118,19 @@ class BasePlugin:
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
|
def _get_video_sample_indices(self, video_stream: "Stream", **kwargs) -> List[int]:
|
||||||
r"""
|
r"""
|
||||||
Computes video sample frames according to fps.
|
Computes video sample indices according to fps.
|
||||||
"""
|
"""
|
||||||
video_fps: float = kwargs["video_fps"]
|
video_fps: float = kwargs["video_fps"]
|
||||||
video_maxlen: int = kwargs["video_maxlen"]
|
video_maxlen: int = kwargs["video_maxlen"]
|
||||||
total_frames = video_stream.frames
|
total_frames = video_stream.frames
|
||||||
sample_frames = float(video_stream.duration * video_stream.time_base) * video_fps
|
if total_frames == 0: # infinite video
|
||||||
|
return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
|
||||||
|
|
||||||
|
sample_frames = math.floor(float(video_stream.duration * video_stream.time_base) * video_fps)
|
||||||
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
sample_frames = min(total_frames, video_maxlen, sample_frames)
|
||||||
return math.floor(sample_frames)
|
return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
||||||
|
|
||||||
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
|
def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
|
||||||
r"""
|
r"""
|
||||||
@ -159,9 +163,7 @@ class BasePlugin:
|
|||||||
for video in videos:
|
for video in videos:
|
||||||
container = av.open(video, "r")
|
container = av.open(video, "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
total_frames = video_stream.frames
|
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||||
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
|
|
||||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
|
||||||
frames: List["ImageObject"] = []
|
frames: List["ImageObject"] = []
|
||||||
container.seek(0)
|
container.seek(0)
|
||||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
@ -228,7 +230,10 @@ class BasePlugin:
|
|||||||
video_fps=getattr(processor, "video_fps", 2.0),
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||||
)
|
)
|
||||||
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
if "videos" in inspect.signature(video_processor.preprocess).parameters: # qwen2vl processor
|
||||||
|
mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
|
||||||
|
else:
|
||||||
|
mm_inputs.update(video_processor(videos, return_tensors="pt"))
|
||||||
|
|
||||||
if len(audios) != 0:
|
if len(audios) != 0:
|
||||||
audios = self._regularize_audios(
|
audios = self._regularize_audios(
|
||||||
@ -1011,9 +1016,7 @@ class Qwen2vlPlugin(BasePlugin):
|
|||||||
for video in videos:
|
for video in videos:
|
||||||
container = av.open(video, "r")
|
container = av.open(video, "r")
|
||||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||||
total_frames = video_stream.frames
|
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||||
sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
|
|
||||||
sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
|
|
||||||
frames: List["ImageObject"] = []
|
frames: List["ImageObject"] = []
|
||||||
container.seek(0)
|
container.seek(0)
|
||||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||||
|
@ -2156,6 +2156,18 @@ register_model_group(
|
|||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
|
"Qwen2-VL-2B": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B",
|
||||||
|
},
|
||||||
|
"Qwen2-VL-7B": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B",
|
||||||
|
},
|
||||||
|
"Qwen2-VL-72B": {
|
||||||
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B",
|
||||||
|
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B",
|
||||||
|
},
|
||||||
"Qwen2-VL-2B-Instruct": {
|
"Qwen2-VL-2B-Instruct": {
|
||||||
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
|
DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
|
||||||
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",
|
DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user