Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
[model] fix mllama any image (#6637)
* fix mllama any image
* reorder classes

Former-commit-id: 98189c8e4d70bf5f8ee83852a023ed27dfc96900
This commit is contained in: parent 5e699458e5, commit 8f73c75c16
@@ -567,6 +567,85 @@ class MiniCPMVPlugin(BasePlugin):
         return mm_inputs


+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
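For intuition, process_messages above counts the dataset's image placeholders and swaps them for the model's own image token before validating the count against the supplied images. A minimal standalone sketch (the placeholder value matches the repository default; the mllama image token value is an assumption here):

    IMAGE_PLACEHOLDER = "<image>"  # repository default placeholder
    image_token = "<|image|>"      # assumed mllama image token

    content = f"{IMAGE_PLACEHOLDER}What is in this picture?"
    num_image_tokens = content.count(IMAGE_PLACEHOLDER)        # 1, must equal len(images)
    content = content.replace(IMAGE_PLACEHOLDER, image_token)  # "<|image|>What is in this picture?"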
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+        **kwargs,
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
+
+        Returns:
+            pixel_values: tensor with shape
+                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width).
+                          For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        imglens: List[int] = kwargs["imglens"]
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        batch_images = []
+        for image_length in imglens:
+            batch_images.append(images[:image_length])
+            images = images[image_length:]
+
+        return image_processor(batch_images, return_tensors="pt")
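The imglens loop above is the heart of the "any image" fix: instead of wrapping every image in a singleton list (one image per sample), the flat image list is split per sample, so a sample may now carry zero, one, or several images. A minimal sketch of the same splitting logic, with group_by_lengths as an illustrative helper that is not part of the repository:

    from typing import List, TypeVar

    T = TypeVar("T")

    def group_by_lengths(items: List[T], lengths: List[int]) -> List[List[T]]:
        """Split a flat list into per-sample sublists, mirroring the imglens loop above."""
        batches, start = [], 0
        for length in lengths:
            batches.append(items[start : start + length])
            start += length
        return batches

    # Example: three samples holding 2, 0, and 1 images respectively.
    images = ["img_a", "img_b", "img_c"]
    print(group_by_lengths(images, [2, 0, 1]))  # [['img_a', 'img_b'], [], ['img_c']]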
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
+        num_tiles = mm_inputs.pop("num_tiles")
+        image_token_id = getattr(processor, "image_token_id")
+        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+        cross_attention_token_mask = [
+            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+        ]
+        mm_inputs["cross_attention_mask"] = torch.from_numpy(
+            convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=max_image_tiles,
+                length=max(len(input_ids) for input_ids in batch_ids),
+            )
+        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        return mm_inputs
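For intuition about the cross_attention_mask built above: the transformers helpers first record, per sequence, which token spans may attend to each image, then expand that sparse description into a dense (batch_size, length, max_num_images, max_num_tiles) array. A rough numpy sketch of the densification step (illustrative only; the real implementation lives in transformers.models.mllama.processing_mllama):

    import numpy as np

    def toy_dense_mask(token_spans, num_tiles, max_num_tiles, length):
        """token_spans[i] lists [start, end) spans, one per image in sequence i."""
        batch_size = len(token_spans)
        max_num_images = max(len(spans) for spans in token_spans)
        mask = np.zeros((batch_size, length, max_num_images, max_num_tiles), dtype=np.int64)
        for i, spans in enumerate(token_spans):
            for j, (start, end) in enumerate(spans):
                mask[i, start:end, j, : num_tiles[i][j]] = 1
        return mask

    # One sequence of length 6 whose single 4-tile image is visible from token 1 onward.
    mask = toy_dense_mask([[(1, 6)]], num_tiles=[[4]], max_num_tiles=4, length=6)
    print(mask.shape)  # (1, 6, 1, 4)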
+
+
 class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -868,92 +947,17 @@ class VideoLlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)


-class MllamaPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens = 0
-        messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
-            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        return messages
-
-    @override
-    def _get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: "ProcessorMixin",
-    ) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
-
-        Returns:
-            pixel_values: tensor with shape
-                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width).
-                          For example, (2, 1, 4, 3, 560, 560).
-            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
-            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
-            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
-        """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
-        return image_processor([[image] for image in images], return_tensors="pt")
-
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        if len(images) != len(batch_ids):
-            raise ValueError("Mllama only supports one image per sample.")
-
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = torch.from_numpy(
-            convert_sparse_cross_attention_mask_to_dense(
-                cross_attention_token_mask,
-                num_tiles=num_tiles,
-                max_num_tiles=max_image_tiles,
-                length=max(len(input_ids) for input_ids in batch_ids),
-            )
-        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
-        return mm_inputs
-
-
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
     "minicpm_v": MiniCPMVPlugin,
+    "mllama": MllamaPlugin,
     "paligemma": PaliGemmaPlugin,
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
-    "mllama": MllamaPlugin,
 }
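The PLUGINS change simply moves the "mllama" entry into position after "minicpm_v" instead of leaving it at the end of the dict; behavior is unchanged, since the registry is only consulted by name. A minimal sketch of how such a registry is typically consumed (get_plugin is an illustrative stand-in, not the repository's actual accessor):

    def get_plugin(name: str, **kwargs):
        """Instantiate a multimodal plugin by name from the PLUGINS registry (hypothetical helper)."""
        plugin_class = PLUGINS.get(name)
        if plugin_class is None:
            raise ValueError(f"Multimodal plugin `{name}` not found.")
        return plugin_class(**kwargs)

The remaining hunks below touch the chat template definitions (the _register_template calls) rather than the plugin module.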
@@ -759,8 +759,9 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
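The format_function change folds the <|eot_id|> terminator into the content slot instead of keeping it as a separate slot. A toy renderer shows that a single call renders identically either way; the merged slot simply keeps the content and its terminator atomic per rendered element (simplified sketch; the real FunctionFormatter also parses tool-call JSON):

    def render_slots(slots, content):
        """Toy slot renderer: substitute {{content}} in each string slot."""
        return "".join(slot.replace("{{content}}", content) for slot in slots)

    print(render_slots(["{{content}}", "<|eot_id|>"], '{"name": "f"}'))  # {"name": "f"}<|eot_id|>
    print(render_slots(["{{content}}<|eot_id|>"], '{"name": "f"}'))      # {"name": "f"}<|eot_id|>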
@@ -786,8 +787,9 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -838,8 +840,9 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
     format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
-    format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
@@ -1130,15 +1133,18 @@ _register_template(
             )
         ]
     ),
+    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
+    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
+    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
     format_observation=StringFormatter(
         slots=[
             (
-                "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+                "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                 "<|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         ]
     ),
     format_tools=ToolFormatter(tool_format="llama3"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     default_system=(
         "You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
@@ -1147,7 +1153,7 @@ _register_template(
         "After completing your thoughts, you then provide a detailed explanation of the solution process "
         "in your response."
     ),
-    stop_words=["<|eot_id|>"],
+    stop_words=["<|eot_id|>", "<|eom_id|>"],
 )
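Adding <|eom_id|> alongside <|eot_id|> accounts for Llama-3.1-style tool calling, where the model can close a tool-call turn with the end-of-message token rather than the end-of-turn token; this pairs with the ipython observation role introduced above. A minimal sketch of how stop words terminate decoding (toy check, not the repository's generation code):

    STOP_WORDS = ["<|eot_id|>", "<|eom_id|>"]

    def should_stop(generated_text: str) -> bool:
        """Stop decoding once the generated text ends with any stop word."""
        return any(generated_text.endswith(word) for word in STOP_WORDS)

    assert should_stop("The answer is 42.<|eot_id|>")
    assert should_stop('{"name": "search"}<|eom_id|>')
    assert not should_stop("Still generating")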