mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-17 20:30:36 +08:00
[data] fix qwen omni plugin (#9204)
Co-authored-by: kingsley <kingsleydodonow@gmail.com>
This commit is contained in:
@@ -1397,8 +1397,8 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
|
||||
@dataclass
|
||||
class Qwen2VLPlugin(BasePlugin):
|
||||
start_token: str = "<|vision_start|>"
|
||||
end_token: str = "<|vision_end|>"
|
||||
vision_bos_token: str = "<|vision_start|>"
|
||||
vision_eos_token: str = "<|vision_end|>"
|
||||
|
||||
@override
|
||||
def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
|
||||
@@ -1515,14 +1515,18 @@ class Qwen2VLPlugin(BasePlugin):
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
|
||||
1,
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
while VIDEO_PLACEHOLDER in content:
|
||||
video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"{self.start_token}{self.video_token * video_seqlen}{self.end_token}", 1
|
||||
VIDEO_PLACEHOLDER,
|
||||
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
|
||||
1,
|
||||
)
|
||||
num_video_tokens += 1
|
||||
|
||||
@@ -1611,7 +1615,9 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
image_grid_thw[num_image_tokens].prod() // image_merge_length if self.expand_mm_tokens else 1
|
||||
)
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER, f"{self.start_token}{self.image_token * image_seqlen}{self.end_token}", 1
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
|
||||
1,
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
@@ -1630,11 +1636,14 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
|
||||
else 1
|
||||
)
|
||||
timestamp_sec = timestamps[frame_index]
|
||||
frame_structure = f"<{timestamp_sec:.1f} seconds>{self.start_token}{self.video_token * video_seqlen}{self.end_token}"
|
||||
frame_structure = (
|
||||
f"<{timestamp_sec:.1f} seconds>"
|
||||
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}"
|
||||
)
|
||||
video_structure += frame_structure
|
||||
|
||||
if not self.expand_mm_tokens:
|
||||
video_structure = f"{self.start_token}{self.video_token}{self.end_token}"
|
||||
video_structure = f"{self.vision_bos_token}{self.video_token}{self.vision_eos_token}"
|
||||
|
||||
content = content.replace(VIDEO_PLACEHOLDER, video_structure, 1)
|
||||
num_video_tokens += 1
|
||||
@@ -1774,7 +1783,11 @@ class GLM4VPlugin(Qwen2VLPlugin):
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
audio_bos_token: str = "<|audio_start|>"
|
||||
audio_eos_token: str = "<|audio_end|>"
|
||||
|
||||
@override
|
||||
def _get_mm_inputs(
|
||||
self,
|
||||
@@ -1861,7 +1874,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
IMAGE_PLACEHOLDER, f"<|vision_bos|>{self.image_token * image_seqlen}<|vision_eos|>", 1
|
||||
IMAGE_PLACEHOLDER,
|
||||
f"{self.vision_bos_token}{self.image_token * image_seqlen}{self.vision_eos_token}",
|
||||
1,
|
||||
)
|
||||
num_image_tokens += 1
|
||||
|
||||
@@ -1898,7 +1913,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
|
||||
audio_chunk_indices = processor.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
|
||||
placeholder_string = ""
|
||||
placeholder_string += "<|vision_bos|>" + "<|audio_bos|>"
|
||||
placeholder_string += self.vision_bos_token + self.audio_bos_token
|
||||
for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
|
||||
video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
|
||||
audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
|
||||
@@ -1908,7 +1923,7 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
if audio_chunk_index is not None:
|
||||
placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
|
||||
|
||||
placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
|
||||
placeholder_string += self.audio_eos_token + self.vision_eos_token
|
||||
content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "", 1)
|
||||
num_audio_tokens += 1
|
||||
@@ -1917,7 +1932,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
audio_seqlen = audio_lengths[num_audio_tokens] if self.expand_mm_tokens else 1
|
||||
content = content.replace(
|
||||
AUDIO_PLACEHOLDER, f"<|audio_bos|>{self.audio_token * audio_seqlen}<|audio_eos|>", 1
|
||||
AUDIO_PLACEHOLDER,
|
||||
f"{self.audio_bos_token}{self.audio_token * audio_seqlen}{self.audio_eos_token}",
|
||||
1,
|
||||
)
|
||||
num_audio_tokens += 1
|
||||
|
||||
@@ -1926,7 +1943,9 @@ class Qwen2OmniPlugin(Qwen2VLPlugin):
|
||||
video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
|
||||
)
|
||||
content = content.replace(
|
||||
VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_seqlen}<|vision_eos|>", 1
|
||||
VIDEO_PLACEHOLDER,
|
||||
f"{self.vision_bos_token}{self.video_token * video_seqlen}{self.vision_eos_token}",
|
||||
1,
|
||||
)
|
||||
num_video_tokens += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user