mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
[data] fix minicpmo vllm infer (#7870)
This commit is contained in:
parent
035e98035c
commit
fcca3b0b0d
@ -898,115 +898,6 @@ class LlavaNextVideoPlugin(BasePlugin):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MiniCPMVPlugin(BasePlugin):
|
class MiniCPMVPlugin(BasePlugin):
|
||||||
    @override
    def process_messages(
        self,
        messages: list[dict[str, str]],
        images: list["ImageInput"],
        videos: list["VideoInput"],
        audios: list["AudioInput"],
        processor: Optional["MMProcessor"],
    ) -> list[dict[str, str]]:
        """Rewrite chat messages, replacing generic multimodal placeholders with
        MiniCPM-V's own image/audio tags.

        Each IMAGE_PLACEHOLDER / VIDEO_PLACEHOLDER / AUDIO_PLACEHOLDER occurrence
        is first swapped for an intermediate "{{image}}" / "{{audio}}" marker,
        then rendered as "(<image>./</image>)" / "(<audio>./</audio>)"; those tags
        are finally expanded into the processor's per-item placeholder strings.

        Raises:
            ValueError: if images and videos are supplied together, or if the
                number of placeholders does not match the number of media items.
        """
        self._validate_input(processor, images, videos, audios)
        # Placeholders consumed so far; checked against the inputs at the end.
        num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
        messages = deepcopy(messages)  # do not mutate the caller's messages
        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
        mm_inputs = {}
        audio_inputs = {}
        if len(images) != 0 and len(videos) != 0:
            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")

        if len(videos) != 0:
            # Video frames use fixed slicing and no image ids.
            max_slice_nums = 2
            use_image_id = False
            mm_inputs = self._get_mm_inputs([], videos, [], processor)
        else:
            max_slice_nums = image_processor.max_slice_nums
            use_image_id = image_processor.use_image_id

        for i, message in enumerate(messages):
            content = message["content"]
            while IMAGE_PLACEHOLDER in content:
                if num_image_tokens >= len(images):
                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                num_image_tokens += 1

            while VIDEO_PLACEHOLDER in content:
                if num_video_tokens >= len(videos):
                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                # A video expands to one "{{image}}" marker per frame when token
                # expansion is enabled, otherwise to a single marker.
                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
                num_video_tokens += 1

            while AUDIO_PLACEHOLDER in content:
                if num_audio_tokens >= len(audios):
                    raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")

                content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                num_audio_tokens += 1

            message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
                "{{audio}}", "(<audio>./</audio>)"
            )

        if num_image_tokens > 0:
            mm_inputs = self._get_mm_inputs(images, [], [], processor)

        if num_audio_tokens > 0:
            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)

        if mm_inputs:
            # Expand every "(<image>./</image>)" tag into the processor's slice
            # placeholder for the corresponding image.
            pattern = "(<image>./</image>)"
            image_sizes = mm_inputs["image_sizes"]
            idx = 0  # global image index across all messages
            for index, message in enumerate(messages):
                text = message["content"]
                # NOTE(review): `pattern` is used unescaped as a regex here but
                # literally in `split` below; the counts only agree when the text
                # contains no regex near-misses of the tag — verify upstream.
                image_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(image_tags)):
                    final_text = (
                        final_text
                        + text_chunks[i]
                        + image_processor.get_slice_image_placeholder(
                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
                        )
                    )
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        if audio_inputs:
            # Same expansion for audio tags, using precomputed placeholder strings.
            pattern = "(<audio>./</audio>)"
            idx = 0  # global audio index across all messages
            for index, message in enumerate(messages):
                text = message["content"]
                audio_tags = re.findall(pattern, text)
                text_chunks = text.split(pattern)
                final_text = ""
                for i in range(len(audio_tags)):
                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
                    final_text = final_text + text_chunks[i] + audio_placeholder
                    idx += 1

                final_text += text_chunks[-1]
                messages[index]["content"] = final_text

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        if len(videos) != num_video_tokens:
            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

        if len(audios) != num_audio_tokens:
            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")

        return messages
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def _get_mm_inputs(
|
def _get_mm_inputs(
|
||||||
self,
|
self,
|
||||||
@ -1077,6 +968,114 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
|
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
images: list["ImageInput"],
|
||||||
|
videos: list["VideoInput"],
|
||||||
|
audios: list["AudioInput"],
|
||||||
|
processor: Optional["MMProcessor"],
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
self._validate_input(processor, images, videos, audios)
|
||||||
|
num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
image_processor: BaseImageProcessor = getattr(processor, "image_processor")
|
||||||
|
mm_inputs, audio_inputs = {}, {}
|
||||||
|
if len(images) != 0 and len(videos) != 0:
|
||||||
|
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
max_slice_nums = 2
|
||||||
|
use_image_id = False
|
||||||
|
mm_inputs = self._get_mm_inputs([], videos, [], processor)
|
||||||
|
else:
|
||||||
|
max_slice_nums = image_processor.max_slice_nums
|
||||||
|
use_image_id = image_processor.use_image_id
|
||||||
|
|
||||||
|
for i, message in enumerate(messages):
|
||||||
|
content = message["content"]
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
if num_image_tokens >= len(images):
|
||||||
|
raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
|
num_image_tokens += 1
|
||||||
|
|
||||||
|
while VIDEO_PLACEHOLDER in content:
|
||||||
|
if num_video_tokens >= len(videos):
|
||||||
|
raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
|
||||||
|
content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
|
||||||
|
num_video_tokens += 1
|
||||||
|
|
||||||
|
while AUDIO_PLACEHOLDER in content:
|
||||||
|
if num_audio_tokens >= len(audios):
|
||||||
|
raise ValueError(f"`len(audios)` is less than the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
||||||
|
num_audio_tokens += 1
|
||||||
|
|
||||||
|
message["content"] = content.replace("{{image}}", "(<image>./</image>)").replace(
|
||||||
|
"{{audio}}", "(<audio>./</audio>)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(images):
|
||||||
|
mm_inputs = self._get_mm_inputs(images, [], [], processor)
|
||||||
|
|
||||||
|
if len(audios):
|
||||||
|
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
|
||||||
|
|
||||||
|
if self.expand_mm_tokens and mm_inputs:
|
||||||
|
pattern = "(<image>./</image>)"
|
||||||
|
image_sizes = mm_inputs["image_sizes"]
|
||||||
|
idx = 0
|
||||||
|
for index, message in enumerate(messages):
|
||||||
|
text = message["content"]
|
||||||
|
image_tags = re.findall(pattern, text)
|
||||||
|
text_chunks = text.split(pattern)
|
||||||
|
final_text = ""
|
||||||
|
for i in range(len(image_tags)):
|
||||||
|
final_text = (
|
||||||
|
final_text
|
||||||
|
+ text_chunks[i]
|
||||||
|
+ image_processor.get_slice_image_placeholder(
|
||||||
|
image_sizes[0][idx], idx, max_slice_nums, use_image_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
final_text += text_chunks[-1]
|
||||||
|
messages[index]["content"] = final_text
|
||||||
|
|
||||||
|
if self.expand_mm_tokens and audio_inputs:
|
||||||
|
pattern = "(<audio>./</audio>)"
|
||||||
|
idx = 0
|
||||||
|
for index, message in enumerate(messages):
|
||||||
|
text = message["content"]
|
||||||
|
audio_tags = re.findall(pattern, text)
|
||||||
|
text_chunks = text.split(pattern)
|
||||||
|
final_text = ""
|
||||||
|
for i in range(len(audio_tags)):
|
||||||
|
audio_placeholder = audio_inputs["audio_phs"][0][idx]
|
||||||
|
final_text = final_text + text_chunks[i] + audio_placeholder
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
final_text += text_chunks[-1]
|
||||||
|
messages[index]["content"] = final_text
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
if len(videos) != num_video_tokens:
|
||||||
|
raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
if len(audios) != num_audio_tokens:
|
||||||
|
raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get_mm_inputs(
|
def get_mm_inputs(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user