mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
[data] fix mllama (#7053)
* fix mllama * fix test Former-commit-id: 76314e6ad1ecaa44fcae4375dd0abf4ebaf1f924
This commit is contained in:
parent
ca78ba964d
commit
dca5fe14c2
@ -82,6 +82,7 @@ class MMPluginMixin:
|
||||
|
||||
def _validate_input(
|
||||
self,
|
||||
processor: Optional["ProcessorMixin"],
|
||||
images: Sequence["ImageInput"],
|
||||
videos: Sequence["VideoInput"],
|
||||
audios: Sequence["AudioInput"],
|
||||
@ -89,6 +90,8 @@ class MMPluginMixin:
|
||||
r"""
|
||||
Validates if this model accepts the input modalities.
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
|
||||
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
|
||||
if len(images) != 0 and self.image_token is None:
|
||||
raise ValueError(
|
||||
"This model does not support image input. Please check whether the correct `template` is used."
|
||||
@ -104,6 +107,15 @@ class MMPluginMixin:
|
||||
"This model does not support audio input. Please check whether the correct `template` is used."
|
||||
)
|
||||
|
||||
if self.image_token is not None and processor is None:
|
||||
raise ValueError("Processor was not found, please check and update your processor config.")
|
||||
|
||||
if self.image_token is not None and image_processor is None:
|
||||
raise ValueError("Image processor was not found, please check and update your processor config.")
|
||||
|
||||
if self.audio_token is not None and feature_extractor is None:
|
||||
raise ValueError("Audio feature extractor was not found, please check and update your processor config.")
|
||||
|
||||
def _preprocess_image(
|
||||
self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs
|
||||
) -> "ImageObject":
|
||||
@ -275,7 +287,7 @@ class BasePlugin(MMPluginMixin):
|
||||
r"""
|
||||
Pre-processes input messages before tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return messages
|
||||
|
||||
def process_token_ids(
|
||||
@ -291,7 +303,7 @@ class BasePlugin(MMPluginMixin):
|
||||
r"""
|
||||
Pre-processes token ids after tokenization for VLMs.
|
||||
"""
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return input_ids, labels
|
||||
|
||||
def get_mm_inputs(
|
||||
@ -317,7 +329,7 @@ class BasePlugin(MMPluginMixin):
|
||||
batch_ids: token ids of input samples, shape (batch_size, seq_len)
|
||||
processor: a processor for pre-processing images and videos
|
||||
"""
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return {}
|
||||
|
||||
|
||||
@ -332,7 +344,7 @@ class LlavaPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
|
||||
messages = deepcopy(messages)
|
||||
@ -361,7 +373,7 @@ class LlavaPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
@ -376,7 +388,7 @@ class LlavaNextPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -417,7 +429,7 @@ class LlavaNextPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
@ -432,7 +444,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -493,7 +505,7 @@ class LlavaNextVideoPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
@ -508,7 +520,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
num_video_tokens = 0
|
||||
num_audio_tokens = 0
|
||||
@ -611,6 +623,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
**kwargs,
|
||||
) -> Dict[str, "torch.Tensor"]:
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
|
||||
mm_inputs = {}
|
||||
if len(images) != 0:
|
||||
images = self._regularize_images(
|
||||
@ -645,21 +658,19 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
mm_inputs.update(video_inputs)
|
||||
|
||||
if len(audios) != 0:
|
||||
new_audios = []
|
||||
for audio in audios:
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio = librosa.load(audio, sr=processor.feature_extractor.sampling_rate)[0]
|
||||
new_audios.append(audio)
|
||||
|
||||
audios = self._regularize_audios(
|
||||
audios,
|
||||
sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
|
||||
)
|
||||
if "valid_audio_nums_ls" in kwargs:
|
||||
valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
|
||||
audios_ls = []
|
||||
idx = 0
|
||||
for valid_audio_nums in valid_audio_nums_ls:
|
||||
audios_ls.append(new_audios[idx : idx + valid_audio_nums])
|
||||
audios_ls.append(audios[idx : idx + valid_audio_nums])
|
||||
idx += valid_audio_nums
|
||||
else:
|
||||
audios_ls = [new_audios]
|
||||
audios_ls = [audios]
|
||||
|
||||
audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
|
||||
audios_ls,
|
||||
@ -685,7 +696,7 @@ class MiniCPMVPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
# image bound
|
||||
image_bounds_list = []
|
||||
valid_image_nums_ls = []
|
||||
@ -753,7 +764,7 @@ class MllamaPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@ -787,17 +798,21 @@ class MllamaPlugin(BasePlugin):
|
||||
num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
|
||||
"""
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)
|
||||
batch_images = []
|
||||
for image_length in imglens:
|
||||
batch_images.append(images[:image_length])
|
||||
images = images[image_length:]
|
||||
mm_inputs = {}
|
||||
if len(images) > 0:
|
||||
images = self._regularize_images(
|
||||
images,
|
||||
image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
|
||||
image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
|
||||
)
|
||||
batch_images = []
|
||||
for image_length in imglens:
|
||||
batch_images.append(images[:image_length])
|
||||
images = images[image_length:]
|
||||
|
||||
return image_processor(batch_images, return_tensors="pt")
|
||||
mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
|
||||
|
||||
return mm_inputs
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
@ -811,22 +826,24 @@ class MllamaPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
|
||||
num_tiles = mm_inputs.pop("num_tiles")
|
||||
image_token_id = getattr(processor, "image_token_id")
|
||||
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
|
||||
cross_attention_token_mask = [
|
||||
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
|
||||
]
|
||||
mm_inputs["cross_attention_mask"] = torch.from_numpy(
|
||||
convert_sparse_cross_attention_mask_to_dense(
|
||||
cross_attention_token_mask,
|
||||
num_tiles=num_tiles,
|
||||
max_num_tiles=max_image_tiles,
|
||||
length=max(len(input_ids) for input_ids in batch_ids),
|
||||
)
|
||||
) # shape: (batch_size, length, max_num_images, max_num_tiles)
|
||||
if mm_inputs:
|
||||
num_tiles = mm_inputs.pop("num_tiles")
|
||||
image_token_id = getattr(processor, "image_token_id")
|
||||
max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
|
||||
cross_attention_token_mask = [
|
||||
get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
|
||||
]
|
||||
mm_inputs["cross_attention_mask"] = torch.from_numpy(
|
||||
convert_sparse_cross_attention_mask_to_dense(
|
||||
cross_attention_token_mask,
|
||||
num_tiles=num_tiles,
|
||||
max_num_tiles=max_image_tiles,
|
||||
length=max(len(input_ids) for input_ids in batch_ids),
|
||||
)
|
||||
) # shape: (batch_size, length, max_num_images, max_num_tiles)
|
||||
|
||||
return mm_inputs
|
||||
|
||||
|
||||
@ -841,7 +858,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens = 0
|
||||
messages = deepcopy(messages)
|
||||
for message in messages:
|
||||
@ -868,7 +885,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Tuple[List[int], Optional[List[int]]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_images = len(images)
|
||||
image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0 # skip mm token
|
||||
image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
|
||||
@ -890,7 +907,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
seqlens = [len(input_ids) for input_ids in batch_ids]
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
|
||||
@ -908,7 +925,7 @@ class PixtralPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
patch_size = getattr(processor, "patch_size")
|
||||
image_token = getattr(processor, "image_token")
|
||||
image_break_token = getattr(processor, "image_break_token")
|
||||
@ -956,7 +973,7 @@ class PixtralPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
mm_inputs.pop("image_sizes", None)
|
||||
return mm_inputs
|
||||
@ -973,7 +990,7 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
bos_token: str = getattr(processor, "audio_bos_token")
|
||||
eos_token: str = getattr(processor, "audio_eos_token")
|
||||
mm_inputs = self._get_mm_inputs([], [], audios, processor)
|
||||
@ -1015,7 +1032,7 @@ class Qwen2AudioPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
@ -1105,7 +1122,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
merge_length: int = getattr(image_processor, "merge_size") ** 2
|
||||
if self.expand_mm_tokens:
|
||||
@ -1162,7 +1179,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
fps_per_video = mm_inputs.pop("fps_per_video", [])
|
||||
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||
@ -1183,7 +1200,7 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
audios: Sequence["AudioInput"],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> List[Dict[str, str]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
num_image_tokens, num_video_tokens = 0, 0
|
||||
messages = deepcopy(messages)
|
||||
mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
|
||||
@ -1241,7 +1258,7 @@ class VideoLlavaPlugin(BasePlugin):
|
||||
batch_ids: Sequence[List[int]],
|
||||
processor: Optional["ProcessorMixin"],
|
||||
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||
self._validate_input(images, videos, audios)
|
||||
self._validate_input(processor, images, videos, audios)
|
||||
return self._get_mm_inputs(images, videos, audios, processor)
|
||||
|
||||
|
||||
|
@ -103,15 +103,17 @@ def _check_plugin(
|
||||
expected_no_mm_inputs: Dict[str, Any] = {},
|
||||
) -> None:
|
||||
# test mm_messages
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
|
||||
expected_mm_inputs,
|
||||
)
|
||||
if plugin.__class__.__name__ != "BasePlugin":
|
||||
assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
|
||||
expected_input_ids,
|
||||
expected_labels,
|
||||
)
|
||||
_is_close(
|
||||
plugin.get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, IMGLENS, NO_VIDLENS, NO_AUDLENS, BATCH_IDS, processor),
|
||||
expected_mm_inputs,
|
||||
)
|
||||
|
||||
# test text_messages
|
||||
assert plugin.process_messages(TEXT_MESSAGES, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == TEXT_MESSAGES
|
||||
assert plugin.process_token_ids(INPUT_IDS, LABELS, NO_IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
|
||||
@ -128,7 +130,7 @@ def _check_plugin(
|
||||
|
||||
def test_base_plugin():
|
||||
tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA)
|
||||
base_plugin = get_mm_plugin(name="base", image_token="<image>")
|
||||
base_plugin = get_mm_plugin(name="base")
|
||||
check_inputs = {"plugin": base_plugin, **tokenizer_module}
|
||||
_check_plugin(**check_inputs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user