mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-28 10:58:54 +08:00
[misc] code lint (#10439)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -642,7 +642,12 @@ class Gemma4Plugin(BasePlugin):
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
|
||||
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
|
||||
return {
|
||||
"videos": results,
|
||||
"fps_per_video": fps_per_video,
|
||||
"durations": durations,
|
||||
"frames_indices": frames_indices,
|
||||
}
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
@@ -674,8 +679,15 @@ class Gemma4Plugin(BasePlugin):
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
video_metadata = [
|
||||
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
||||
for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
|
||||
{
|
||||
"fps": getattr(processor, "video_fps", 2.0),
|
||||
"duration": duration,
|
||||
"total_num_frames": len(video),
|
||||
"frames_indices": sample_indices,
|
||||
}
|
||||
for video, duration, sample_indices in zip(
|
||||
video_data["videos"], video_data["durations"], video_data["frames_indices"]
|
||||
)
|
||||
]
|
||||
mm_inputs.update(
|
||||
video_processor(
|
||||
@@ -687,7 +699,7 @@ class Gemma4Plugin(BasePlugin):
|
||||
)
|
||||
)
|
||||
|
||||
if len(audios) != 0: # only for gemma4n
|
||||
if len(audios) != 0: # only for gemma4n
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
|
||||
@@ -695,11 +707,11 @@ class Gemma4Plugin(BasePlugin):
|
||||
|
||||
mm_inputs.update(
|
||||
feature_extractor(
|
||||
audios,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
audios,
|
||||
padding="max_length",
|
||||
return_tensors="pt",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@@ -751,7 +763,10 @@ class Gemma4Plugin(BasePlugin):
|
||||
num_soft_tokens_per_frame, metadata = next(video_iter)
|
||||
if self.expand_mm_tokens:
|
||||
timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
|
||||
frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
|
||||
frame_strs = [
|
||||
f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
|
||||
for ts in timestamp_strs
|
||||
]
|
||||
video_str = " ".join(frame_strs)
|
||||
else:
|
||||
video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
|
||||
@@ -760,7 +775,9 @@ class Gemma4Plugin(BasePlugin):
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
current_audio = next(audio_iter)
|
||||
if self.expand_mm_tokens:
|
||||
num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
|
||||
num_audio_tokens = processor._compute_audio_num_tokens(
|
||||
current_audio, processor.feature_extractor.sampling_rate
|
||||
)
|
||||
audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
|
||||
else:
|
||||
audio_str = f"{boa_token}{audio_token}{eoa_token}"
|
||||
@@ -786,8 +803,14 @@ class Gemma4Plugin(BasePlugin):
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
# Pop metadata keys that must not be passed to the model.
|
||||
for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata",
|
||||
"_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"):
|
||||
for key in (
|
||||
"num_soft_tokens_per_image",
|
||||
"num_soft_tokens_per_video",
|
||||
"video_metadata",
|
||||
"_gemma4_fps_per_video",
|
||||
"_gemma4_frames_indices",
|
||||
"_gemma4_num_audio_soft_tokens",
|
||||
):
|
||||
mm_inputs.pop(key, None)
|
||||
|
||||
mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
|
||||
@@ -1696,7 +1719,9 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
|
||||
original_fps = float(video_stream.average_rate)
|
||||
# for qwen3vl video timestamp calculation
|
||||
frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False
|
||||
frames_indices.append(
|
||||
[idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]
|
||||
) # hack usage when do_sample_frames=False
|
||||
container.seek(0)
|
||||
for frame_idx, frame in enumerate(container.decode(video_stream)):
|
||||
if frame_idx in sample_indices:
|
||||
@@ -1715,7 +1740,12 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
frames = self._regularize_images(frames, **kwargs)["images"]
|
||||
results.append(frames)
|
||||
|
||||
return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
|
||||
return {
|
||||
"videos": results,
|
||||
"fps_per_video": fps_per_video,
|
||||
"durations": durations,
|
||||
"frames_indices": frames_indices,
|
||||
}
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
@@ -1830,8 +1860,15 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
video_maxlen=getattr(processor, "video_maxlen", 128),
|
||||
)
|
||||
video_metadata = [
|
||||
{"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
|
||||
for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
|
||||
{
|
||||
"fps": getattr(processor, "video_fps", 2.0),
|
||||
"duration": duration,
|
||||
"total_num_frames": len(video),
|
||||
"frames_indices": sample_indices,
|
||||
}
|
||||
for video, duration, sample_indices in zip(
|
||||
videos["videos"], videos["durations"], videos["frames_indices"]
|
||||
)
|
||||
]
|
||||
mm_inputs.update(
|
||||
video_processor(
|
||||
@@ -1839,7 +1876,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
video_metadata=video_metadata,
|
||||
fps=getattr(processor, "video_fps", 2.0),
|
||||
return_metadata=True,
|
||||
do_sample_frames=False, # avoid changing frames_indices
|
||||
do_sample_frames=False, # avoid changing frames_indices
|
||||
)
|
||||
)
|
||||
temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
|
||||
|
||||
Reference in New Issue
Block a user