mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-04-21 20:36:02 +08:00
[model] gemma4 (#10346)
This commit is contained in:
@@ -607,6 +607,194 @@ class Gemma3nPlugin(Gemma3Plugin):
|
||||
return messages
|
||||
|
||||
|
||||
@dataclass
class Gemma4Plugin(BasePlugin):
    r"""Plugin for the Gemma4 multimodal model.

    Expands image/video/audio placeholders in chat messages into the model's
    special-token sequences and builds the processor's multimodal tensor inputs.
    Videos additionally carry per-video FPS, duration and sampled-frame indices
    so the processor can emit per-frame timestamps.
    """

    @override
    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> "RegularizedVideoOutput":
        r"""Regularize videos, also tracking per-video FPS, duration and frame indices.

        Returns a dict with keys ``videos`` (list of frame lists), ``fps_per_video``,
        ``durations`` (seconds) and ``frames_indices`` (per-video sampled indices,
        rescaled to the target-fps timeline for decoded files).
        """
        results, fps_per_video, durations, frames_indices = [], [], [], []
        # Hoisted: the target sampling fps is used for fallback durations and timestamp scaling.
        video_fps = kwargs.get("video_fps", 2.0)
        for video in videos:
            frames: list[ImageObject] = []
            if _check_video_is_nested_images(video):
                # The "video" is already a list of frames; treat it as sampled at `video_fps`.
                frames = video
                fps_per_video.append(video_fps)
                durations.append(len(frames) / video_fps)
                frames_indices.append(list(range(len(frames))))
            else:
                # FIX: close the container deterministically (the original leaked it).
                with av.open(video, "r") as container:
                    video_stream = next(stream for stream in container.streams if stream.type == "video")
                    sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
                    original_fps = float(video_stream.average_rate)
                    # Rescale sampled indices into the target-fps timeline so timestamps are computed correctly.
                    frames_indices.append([idx / original_fps * video_fps for idx in sample_indices])
                    container.seek(0)
                    for frame_idx, frame in enumerate(container.decode(video_stream)):
                        if frame_idx in sample_indices:
                            frames.append(frame.to_image())

                    # FIX: the original never appended here, leaving `fps_per_video`
                    # shorter than `videos` for decoded files; keep the lists aligned.
                    fps_per_video.append(video_fps)
                    if video_stream.duration is None:
                        # No container duration available: estimate from the sampled frame count.
                        durations.append(len(frames) / video_fps)
                    else:
                        durations.append(float(video_stream.duration * video_stream.time_base))

            frames = self._regularize_images(frames, **kwargs)["images"]
            results.append(frames)

        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}

    @override
    def _get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: "MMProcessor",
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build processor tensor inputs for images, videos and audios.

        Videos are passed with explicit per-video metadata (fps, duration, total
        frames, sampled indices) and ``do_sample_frames=False`` since sampling
        already happened in ``_regularize_videos``.
        """
        image_processor = getattr(processor, "image_processor", None)
        video_processor = getattr(processor, "video_processor", None)
        feature_extractor = getattr(processor, "feature_extractor", None)
        mm_inputs = {}

        if len(images) != 0:
            regularized = self._regularize_images(
                images,
                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
            )["images"]
            mm_inputs.update(image_processor(regularized, return_tensors="pt"))

        if len(videos) != 0:
            video_data = self._regularize_videos(
                videos,
                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                video_fps=getattr(processor, "video_fps", 2.0),
                video_maxlen=getattr(processor, "video_maxlen", 128),
            )
            video_metadata = [
                {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
                for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
            ]
            mm_inputs.update(
                video_processor(
                    videos=video_data["videos"],
                    video_metadata=video_metadata,
                    return_tensors="pt",
                    return_metadata=True,
                    do_sample_frames=False,  # frames were already sampled in _regularize_videos
                )
            )

        if len(audios) != 0:  # only for gemma4n
            audios = self._regularize_audios(
                audios,
                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
            )["audios"]

            mm_inputs.update(
                feature_extractor(
                    audios,
                    padding="max_length",
                    return_tensors="pt",
                )
            )

        return mm_inputs

    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        r"""Replace media placeholders in messages with expanded special-token strings.

        When ``expand_mm_tokens`` is set, each video frame is prefixed with an
        ``MM:SS`` timestamp taken from the processor-returned video metadata, and
        audio token counts are computed from the raw waveform.
        """
        self._validate_input(processor, images, videos, audios)
        self._validate_messages(messages, images, videos, audios)
        messages = deepcopy(messages)

        boi_token: str = getattr(processor, "boi_token")
        eoi_token: str = getattr(processor, "eoi_token")
        boa_token: str = getattr(processor, "boa_token")
        eoa_token: str = getattr(processor, "eoa_token")
        image_token: str = getattr(processor, "image_token")
        video_token: str = getattr(processor, "video_token")
        audio_token: str = getattr(processor, "audio_token")

        if self.expand_mm_tokens:
            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
            num_image_soft_tokens: list[int] = list(
                mm_inputs.get("num_soft_tokens_per_image", [getattr(processor, "image_seq_length", 256)] * len(images))
            )
            num_video_soft_tokens: list[int] = list(mm_inputs.get("num_soft_tokens_per_video", [1] * len(videos)))
            video_metadata = mm_inputs.get("video_metadata", [])
        else:
            # Without expansion each placeholder becomes a single token.
            num_image_soft_tokens = [1] * len(images)
            num_video_soft_tokens = [1] * len(videos)
            video_metadata = [None] * len(videos)

        audio_iter = iter(audios)
        image_iter = iter(num_image_soft_tokens)
        video_iter = iter(zip(num_video_soft_tokens, video_metadata))

        for message in messages:
            content = message["content"]

            while IMAGE_PLACEHOLDER in content:
                n = next(image_iter)
                content = content.replace(IMAGE_PLACEHOLDER, f"{boi_token}{image_token * n}{eoi_token}", 1)

            while VIDEO_PLACEHOLDER in content:
                num_soft_tokens_per_frame, metadata = next(video_iter)
                if self.expand_mm_tokens:
                    # One "MM:SS <frame tokens>" chunk per sampled frame.
                    timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
                    frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
                    video_str = " ".join(frame_strs)
                else:
                    video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
                content = content.replace(VIDEO_PLACEHOLDER, video_str, 1)

            while AUDIO_PLACEHOLDER in content:
                current_audio = next(audio_iter)
                if self.expand_mm_tokens:
                    num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
                    audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
                else:
                    audio_str = f"{boa_token}{audio_token}{eoa_token}"

                content = content.replace(AUDIO_PLACEHOLDER, audio_str, 1)

            message["content"] = content

        return messages

    @override
    def get_mm_inputs(
        self,
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        imglens: list[int],
        vidlens: list[int],
        audlens: list[int],
        batch_ids: list[list[int]],
        processor: Optional["MMProcessor"],
    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
        r"""Build the batched multimodal model inputs, stripping metadata-only keys."""
        self._validate_input(processor, images, videos, audios)
        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
        # Pop metadata keys that must not be passed to the model.
        for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata",
                    "_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"):
            mm_inputs.pop(key, None)

        mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)

        return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class InternVLPlugin(BasePlugin):
|
||||
@override
|
||||
@@ -1505,7 +1693,7 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
else:
|
||||
container = av.open(video, "r")
|
||||
video_stream = next(stream for stream in container.streams if stream.type == "video")
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
original_fps = float(video_stream.average_rate)
|
||||
# for qwen3vl video timestamp calculation
|
||||
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False
|
||||
@@ -1642,7 +1830,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
video_metadata = [
|
||||
{"fps": getattr(processor, "video_fps", 24.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
||||
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
||||
for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
|
||||
]
|
||||
mm_inputs.update(
|
||||
@@ -1683,7 +1871,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
image_grid_thw = mm_inputs.get("image_grid_thw", [])
|
||||
video_grid_thw = mm_inputs.get("video_grid_thw", [])
|
||||
num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
|
||||
video_metadata = mm_inputs.get("video_metadata", {})
|
||||
video_metadata = mm_inputs.get("video_metadata", [])
|
||||
|
||||
else:
|
||||
image_grid_thw = [None] * len(images)
|
||||
@@ -2206,8 +2394,9 @@ PLUGINS = {
|
||||
"base": BasePlugin,
|
||||
"ernie_vl": ErnieVLPlugin,
|
||||
"gemma3": Gemma3Plugin,
|
||||
"glm4v": GLM4VPlugin,
|
||||
"gemma3n": Gemma3nPlugin,
|
||||
"gemma4": Gemma4Plugin,
|
||||
"glm4v": GLM4VPlugin,
|
||||
"intern_vl": InternVLPlugin,
|
||||
"kimi_vl": KimiVLPlugin,
|
||||
"llama4": Llama4Plugin,
|
||||
|
||||
Reference in New Issue
Block a user