From 1fcedf9af643f10fe6de410b4fa60e9c15544217 Mon Sep 17 00:00:00 2001
From: Zhangchi Feng <64362896+BUAADreamer@users.noreply.github.com>
Date: Wed, 19 Feb 2025 19:36:04 +0800
Subject: [PATCH] [data] fix MiniCPMV plugin (#6998)

* fix template

* fix bug in messages processing

Former-commit-id: cde479e47a51beb60ab555cdee083c1cdba0ead6
---
 src/llamafactory/data/mm_plugin.py | 35 +++++++++++++++++++------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 2c7a81c2..4947ff41 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -516,7 +516,6 @@ class MiniCPMVPlugin(BasePlugin):
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
         mm_inputs = {}
         audio_inputs = {}
-        audio_parts = []
         if len(images) != 0 and len(videos) != 0:
             raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
 
@@ -540,7 +539,6 @@ class MiniCPMVPlugin(BasePlugin):
                 num_video_tokens += 1
 
             while AUDIO_PLACEHOLDER in content:
-                audio_parts.append(i)
                 content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
                 num_audio_tokens += 1
 
@@ -552,12 +550,12 @@ class MiniCPMVPlugin(BasePlugin):
             mm_inputs = self._get_mm_inputs(images, [], [], processor)
 
         if num_audio_tokens > 0:
-            audio_parts_ls = [audio_parts]
-            audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls, ret_phs=True)
+            audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
 
         if mm_inputs:
             pattern = "(<image>./</image>)"
             image_sizes = mm_inputs["image_sizes"]
+            idx = 0
             for index, message in enumerate(messages):
                 text = message["content"]
                 image_tags = re.findall(pattern, text)
@@ -568,23 +566,26 @@ class MiniCPMVPlugin(BasePlugin):
                         final_text
                         + text_chunks[i]
                         + image_processor.get_slice_image_placeholder(
-                            image_sizes[0][i], i, max_slice_nums, use_image_id
+                            image_sizes[0][idx], idx, max_slice_nums, use_image_id
                         )
                     )
+                    idx += 1
 
                 final_text += text_chunks[-1]
                 messages[index]["content"] = final_text
 
         if audio_inputs:
             pattern = "(<audio>./</audio>)"
+            idx = 0
             for index, message in enumerate(messages):
                 text = message["content"]
                 audio_tags = re.findall(pattern, text)
                 text_chunks = text.split(pattern)
                 final_text = ""
                 for i in range(len(audio_tags)):
-                    audio_placeholder = audio_inputs["audio_phs"][0][i]
+                    audio_placeholder = audio_inputs["audio_phs"][0][idx]
                     final_text = final_text + text_chunks[i] + audio_placeholder
+                    idx += 1
 
                 final_text += text_chunks[-1]
                 messages[index]["content"] = final_text
@@ -644,22 +645,24 @@ class MiniCPMVPlugin(BasePlugin):
             mm_inputs.update(video_inputs)
 
         if len(audios) != 0:
-            audio_parts_ls = kwargs.get("audio_parts_ls", None)
             new_audios = []
             for audio in audios:
                 if not isinstance(audio, np.ndarray):
                     audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
                 new_audios.append(audio)
 
-            audios_ls = []
-            idx = 0
-            for audio_parts in audio_parts_ls:
-                audios_ls.append(new_audios[idx : idx + len(audio_parts)])
-                idx += len(audio_parts)
+            if "valid_audio_nums_ls" in kwargs:
+                valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
+                audios_ls = []
+                idx = 0
+                for valid_audio_nums in valid_audio_nums_ls:
+                    audios_ls.append(new_audios[idx : idx + valid_audio_nums])
+                    idx += valid_audio_nums
+            else:
+                audios_ls = [new_audios]
 
             audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                 audios_ls,
-                audio_parts_ls,
                 chunk_input=True,
                 sampling_rate=16000,
             )
@@ -715,7 +718,7 @@ class MiniCPMVPlugin(BasePlugin):
             # audio bound
             audio_bounds_ls = []
             spk_bounds_ls = []
-            audio_parts_ls = []
+            valid_audio_nums_ls = []
 
             for input_ids, audiolen in zip(batch_ids, audlens):
                 input_ids_ = torch.tensor(input_ids)
@@ -724,7 +727,7 @@ class MiniCPMVPlugin(BasePlugin):
                 assert len(audio_start_idx) == len(audio_end_idx)
                 audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
                 audio_bounds_ls.append(audio_bounds)
-                audio_parts_ls.append(list(range(audiolen)))
+                valid_audio_nums_ls.append(audiolen)
 
                 spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
                 spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
@@ -732,7 +735,7 @@ class MiniCPMVPlugin(BasePlugin):
                 spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
                 spk_bounds_ls.append(spk_bounds)
 
-            audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls)
+            audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
             mm_inputs.update(audio_inputs)
             mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
 
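
Reviewer note (not part of the patch): the change has two independent parts.
In process_messages, placeholder expansion now advances a running idx across
all messages of a sample instead of reusing the per-message loop variable i,
so a placeholder in a later message no longer picks up an earlier message's
image size or audio placeholder. In _get_mm_inputs, the audio_parts_ls kwarg
is replaced by valid_audio_nums_ls, a list of per-sample audio counts used to
regroup the flat audio batch. Below is a minimal standalone sketch of that
regrouping logic; the function name regroup_audios is illustrative only and
does not exist in the codebase:

    def regroup_audios(new_audios, valid_audio_nums_ls=None):
        """Split a flat list of audios into one sublist per sample."""
        if valid_audio_nums_ls is None:
            return [new_audios]  # single-group fallback branch added by this patch

        audios_ls, idx = [], 0
        for valid_audio_nums in valid_audio_nums_ls:
            # each sample consumes the next valid_audio_nums audios from the flat list
            audios_ls.append(new_audios[idx : idx + valid_audio_nums])
            idx += valid_audio_nums
        return audios_ls

    # e.g. sample 0 contributed 3 audios and sample 1 contributed 1:
    assert regroup_audios(["a", "b", "c", "d"], [3, 1]) == [["a", "b", "c"], ["d"]]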