mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
[data] fix MiniCPMV plugin (#6998)
* fix template * fix bug in messages processing Former-commit-id: cde479e47a51beb60ab555cdee083c1cdba0ead6
This commit is contained in:
parent
b0bbacaacb
commit
1fcedf9af6
@ -516,7 +516,6 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
mm_inputs = {}
|
||||
audio_inputs = {}
|
||||
audio_parts = []
|
||||
if len(images) != 0 and len(videos) != 0:
|
||||
raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
|
||||
|
||||
@ -540,7 +539,6 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
num_video_tokens += 1
|
||||
|
||||
while AUDIO_PLACEHOLDER in content:
|
||||
audio_parts.append(i)
|
||||
content = content.replace(AUDIO_PLACEHOLDER, "{{audio}}", 1)
|
||||
num_audio_tokens += 1
|
||||
|
||||
@ -552,12 +550,12 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
mm_inputs = self._get_mm_inputs(images, [], [], processor)
|
||||
|
||||
if num_audio_tokens > 0:
|
||||
audio_parts_ls = [audio_parts]
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls, ret_phs=True)
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor, ret_phs=True)
|
||||
|
||||
if mm_inputs:
|
||||
pattern = "(<image>./</image>)"
|
||||
image_sizes = mm_inputs["image_sizes"]
|
||||
idx = 0
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
image_tags = re.findall(pattern, text)
|
||||
@ -568,23 +566,26 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
final_text
|
||||
+ text_chunks[i]
|
||||
+ image_processor.get_slice_image_placeholder(
|
||||
image_sizes[0][i], i, max_slice_nums, use_image_id
|
||||
image_sizes[0][idx], idx, max_slice_nums, use_image_id
|
||||
)
|
||||
)
|
||||
idx += 1
|
||||
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
|
||||
if audio_inputs:
|
||||
pattern = "(<audio>./</audio>)"
|
||||
idx = 0
|
||||
for index, message in enumerate(messages):
|
||||
text = message["content"]
|
||||
audio_tags = re.findall(pattern, text)
|
||||
text_chunks = text.split(pattern)
|
||||
final_text = ""
|
||||
for i in range(len(audio_tags)):
|
||||
audio_placeholder = audio_inputs["audio_phs"][0][i]
|
||||
audio_placeholder = audio_inputs["audio_phs"][0][idx]
|
||||
final_text = final_text + text_chunks[i] + audio_placeholder
|
||||
idx += 1
|
||||
|
||||
final_text += text_chunks[-1]
|
||||
messages[index]["content"] = final_text
|
||||
@ -644,22 +645,24 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
if len(audios) != 0:
|
||||
audio_parts_ls = kwargs.get("audio_parts_ls", None)
|
||||
new_audios = []
|
||||
for audio in audios:
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
|
||||
new_audios.append(audio)
|
||||
|
||||
audios_ls = []
|
||||
idx = 0
|
||||
for audio_parts in audio_parts_ls:
|
||||
audios_ls.append(new_audios[idx : idx + len(audio_parts)])
|
||||
idx += len(audio_parts)
|
||||
if "valid_audio_nums_ls" in kwargs:
|
||||
valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
|
||||
audios_ls = []
|
||||
idx = 0
|
||||
for valid_audio_nums in valid_audio_nums_ls:
|
||||
audios_ls.append(new_audios[idx : idx + valid_audio_nums])
|
||||
idx += valid_audio_nums
|
||||
else:
|
||||
audios_ls = [new_audios]
|
||||
|
||||
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
|
||||
audios_ls,
|
||||
audio_parts_ls,
|
||||
chunk_input=True,
|
||||
sampling_rate=16000,
|
||||
)
|
||||
@ -715,7 +718,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
# audio bound
|
||||
audio_bounds_ls = []
|
||||
spk_bounds_ls = []
|
||||
audio_parts_ls = []
|
||||
valid_audio_nums_ls = []
|
||||
|
||||
for input_ids, audiolen in zip(batch_ids, audlens):
|
||||
input_ids_ = torch.tensor(input_ids)
|
||||
@ -724,7 +727,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
assert len(audio_start_idx) == len(audio_end_idx)
|
||||
audio_bounds = torch.hstack([(audio_start_idx + 1).unsqueeze(-1), audio_end_idx.unsqueeze(-1)])
|
||||
audio_bounds_ls.append(audio_bounds)
|
||||
audio_parts_ls.append(list(range(audiolen)))
|
||||
valid_audio_nums_ls.append(audiolen)
|
||||
|
||||
spk_start_idx = torch.where(input_ids_ == processor.tokenizer.spk_start_id)[0]
|
||||
spk_end_idx = torch.where(input_ids_ == processor.tokenizer.spk_end_id)[0]
|
||||
@ -732,7 +735,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
spk_bounds = torch.hstack([(spk_start_idx + 1).unsqueeze(-1), spk_end_idx.unsqueeze(-1)])
|
||||
spk_bounds_ls.append(spk_bounds)
|
||||
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor, audio_parts_ls=audio_parts_ls)
|
||||
audio_inputs = self._get_mm_inputs([], [], audios, processor, valid_audio_nums_ls=valid_audio_nums_ls)
|
||||
mm_inputs.update(audio_inputs)
|
||||
mm_inputs.update({"audio_bounds": audio_bounds_ls, "spk_bounds": spk_bounds_ls})
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user