[data] fix MiniCPMV plugin (#6998)

* fix template

* fix bug in messages processing: keep a running placeholder index across messages instead of resetting it for every message

Former-commit-id: cde479e47a51beb60ab555cdee083c1cdba0ead6
Zhangchi Feng 2025-02-19 19:36:04 +08:00 committed by GitHub
parent b0bbacaacb
commit 1fcedf9af6


@@ -516,7 +516,6 @@ class MiniCPMVPlugin(BasePlugin):
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
mm_inputs = {}
audio_inputs = {}
-audio_parts = []
if len(images) != 0 and len(videos) != 0:
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
@@ -540,7 +539,6 @@ class MiniCPMVPlugin(BasePlugin):
num_video_tokens += 1
while AUDIO_PLACEHOLDER in content:
-audio_parts.append(i)
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
num_audio_tokens += 1
@@ -552,12 +550,12 @@ class MiniCPMVPlugin(BasePlugin):
mm_inputs = self._get_mm_inputs(images, [], [], processor)
if num_audio_tokens > 0:
-audio_parts_ls = [audio_parts]
-audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls, ret_phs=True)
+audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
if mm_inputs:
pattern = "(<image>./</image>)"
image_sizes = mm_inputs["image_sizes"]
+idx = 0
for index, message in enumerate(messages):
text = message["content"]
image_tags = re.findall(pattern, text)
@@ -568,23 +566,26 @@ class MiniCPMVPlugin(BasePlugin):
final_text
+ text_chunks[i]
+ image_processor.get_slice_image_placeholder(
-image_sizes[0][i], i, max_slice_nums, use_image_id
+image_sizes[0][idx], idx, max_slice_nums, use_image_id
)
)
+idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
if audio_inputs:
pattern = "(<audio>./</audio>)"
+idx = 0
for index, message in enumerate(messages):
text = message["content"]
audio_tags = re.findall(pattern, text)
text_chunks = text.split(pattern)
final_text = ""
for i in range(len(audio_tags)):
-audio_placeholder = audio_inputs["audio_phs"][0][i]
+audio_placeholder = audio_inputs["audio_phs"][0][idx]
final_text = final_text + text_chunks[i] + audio_placeholder
+idx += 1
final_text += text_chunks[-1]
messages[index]["content"] = final_text
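
The hunk above fixes the placeholder indexing: the inner loop variable i restarts at 0 for every message, so in multi-turn samples image_sizes[0] and audio_phs[0] were re-indexed from the first message's positions, while the new idx counter keeps advancing across the whole conversation. A minimal standalone sketch of that pattern (expand_placeholders and the placeholder strings are illustrative, not plugin code):

# Sketch of the running-index fix, assuming one flat per-sample list of
# placeholder values shared by all messages in the conversation.
def expand_placeholders(messages, placeholder_values, tag="(<image>./</image>)"):
    """Replace each occurrence of `tag` with the next value from the flat list."""
    idx = 0  # survives across messages; the old code reused the per-message index
    for message in messages:
        chunks = message["content"].split(tag)
        pieces = []
        for chunk in chunks[:-1]:  # one tag sits between consecutive chunks
            pieces.append(chunk)
            pieces.append(placeholder_values[idx])
            idx += 1
        pieces.append(chunks[-1])
        message["content"] = "".join(pieces)
    return messages

msgs = [
    {"role": "user", "content": "(<image>./</image>) first turn"},
    {"role": "user", "content": "(<image>./</image>) second turn"},
]
expand_placeholders(msgs, ["<ph_0>", "<ph_1>"])
# A per-message index would give both turns <ph_0>; the running idx yields <ph_1>
# for the second turn, matching the image_sizes[0][idx] / audio_phs[0][idx] change.
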
@@ -644,22 +645,24 @@ class MiniCPMVPlugin(BasePlugin):
mm_inputs.update(video_inputs)
if len(audios) != 0:
-audio_parts_ls = kwargs.get("audio_parts_ls", None)
new_audios = []
for audio in audios:
if not isinstance(audio, np.ndarray):
audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
new_audios.append(audio)
-audios_ls = []
-idx = 0
-for audio_parts in audio_parts_ls:
-audios_ls.append(new_audios[idx : idx + len(audio_parts)])
-idx += len(audio_parts)
+if "valid_audio_nums_ls" in kwargs:
+valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
+audios_ls = []
+idx = 0
+for valid_audio_nums in valid_audio_nums_ls:
+audios_ls.append(new_audios[idx : idx + valid_audio_nums])
+idx += valid_audio_nums
+else:
+audios_ls = [new_audios]
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
audios_ls,
-audio_parts_ls,
chunk_input=True,
sampling_rate=16000,
)
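
In the new _get_mm_inputs path the flat list of decoded audios is grouped per sample by valid_audio_nums_ls (how many audio placeholders each sample contains) instead of the removed audio_parts_ls. A small sketch of that grouping, with group_by_counts as an illustrative helper rather than plugin code:

from typing import List, TypeVar

T = TypeVar("T")

def group_by_counts(flat: List[T], counts: List[int]) -> List[List[T]]:
    """Slice a flat batch-level list into consecutive per-sample sublists."""
    grouped, idx = [], 0
    for count in counts:
        grouped.append(flat[idx : idx + count])
        idx += count
    return grouped

# Three audios for a batch of two samples: the first sample carries two
# <audio> placeholders, the second carries one.
assert group_by_counts(["a0", "a1", "a2"], [2, 1]) == [["a0", "a1"], ["a2"]]
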
@@ -715,7 +718,7 @@ class MiniCPMVPlugin(BasePlugin):
# audio bound
audio_bounds_ls = []
spk_bounds_ls = []
-audio_parts_ls = []
+valid_audio_nums_ls = []
for input_ids, audiolen in zip(batch_ids, audlens):
input_ids_ = torch.tensor(input_ids)
@@ -724,7 +727,7 @@ class MiniCPMVPlugin(BasePlugin):
assert len(audio_start_idx) == len(audio_end_idx)
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
audio_bounds_ls.append(audio_bounds)
-audio_parts_ls.append(list(range(audiolen)))
+valid_audio_nums_ls.append(audiolen)
spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
@@ -732,7 +735,7 @@ class MiniCPMVPlugin(BasePlugin):
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
spk_bounds_ls.append(spk_bounds)
-audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls)
+audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
mm_inputs.update(audio_inputs)
mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
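
Downstream, valid_audio_nums_ls is simply the number of audio spans found in each sequence (audiolen), so the bound computation above is unchanged by this commit. For reference, a self-contained sketch of how such (start + 1, end) bound pairs are located between special token ids; the ids 101/102 below are made up for the example:

import torch

def find_bounds(input_ids: torch.Tensor, start_id: int, end_id: int) -> torch.Tensor:
    """Return an (N, 2) tensor of [index after start_id, index of end_id] per span."""
    start_idx = torch.where(input_ids == start_id)[0]
    end_idx = torch.where(input_ids == end_id)[0]
    assert len(start_idx) == len(end_idx), "unbalanced special tokens"
    return torch.hstack([(start_idx + 1).unsqueeze(-1), end_idx.unsqueeze(-1)])

ids = torch.tensor([5, 101, 7, 7, 102, 9, 101, 7, 102])  # 101/102: fake start/end ids
bounds = find_bounds(ids, 101, 102)  # tensor([[2, 4], [7, 8]])
# len(bounds) is the per-sample audio count that now feeds valid_audio_nums_ls.
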