mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
[model] fix mllama any image (#6637)
* fix mllama any image
* reorder classes

Former-commit-id: 98189c8e4d70bf5f8ee83852a023ed27dfc96900
This commit is contained in:
parent
5e699458e5
commit
8f73c75c16
@ -567,6 +567,85 @@ class MiniCPMVPlugin(BasePlugin):
|
|||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
|
class MllamaPlugin(BasePlugin):
    r"""
    Multimodal plugin for mllama.

    Replaces image placeholders with the plugin's image token and prepares the image processor
    outputs plus a dense cross-attention mask. Images are regrouped per sample according to
    ``imglens``, so a sample is not limited to a single image.
    """

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Replaces every image placeholder in the message contents with the image token.

        The messages are deep-copied first, so the caller's objects are not modified.

        Raises:
            ValueError: if the number of placeholders differs from the number of images.
        """
        self._validate_input(images, videos)
        num_image_tokens = 0
        messages = deepcopy(messages)
        for message in messages:
            content = message["content"]
            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)

        if len(images) != num_image_tokens:
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return messages

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: "ProcessorMixin",
        **kwargs,
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

        Keyword Args:
            imglens: number of images of each sample, in batch order. The flat ``images`` sequence
                is cut into consecutive groups of these sizes. When omitted, every image is
                treated as a sample of its own.

        Returns:
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        imglens: Optional[Sequence[int]] = kwargs.get("imglens")
        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
        if imglens is None:
            # Called with the plain (images, videos, processor) arguments: fall back to one image
            # per sample instead of failing with a KeyError.
            batch_images = [[image] for image in images]
        else:
            batch_images = []
            for image_length in imglens:
                # Take this sample's images off the front, then drop them from the remainder.
                batch_images.append(images[:image_length])
                images = images[image_length:]

        return image_processor(batch_images, return_tensors="pt")

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds the multimodal inputs of a batch, including the dense cross-attention mask.

        Args:
            images: all images of the batch as one flat sequence.
            videos: only forwarded to input validation and ``_get_mm_inputs``.
            imglens: number of images of each sample, used to regroup ``images`` per sample.
            vidlens: not used by this plugin.
            batch_ids: token ids of each sample; the image token is looked up in them.
            processor: typed as optional, but ``image_processor`` and ``image_token_id`` are read
                from it, so it has to be provided here.

        Returns:
            The image processor outputs without ``num_tiles``, plus ``cross_attention_mask``.
        """
        self._validate_input(images, videos)
        mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
        # num_tiles is only needed to build the mask, so it is removed from the returned inputs.
        num_tiles = mm_inputs.pop("num_tiles")
        image_token_id = getattr(processor, "image_token_id")
        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
        cross_attention_token_mask = [
            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
        ]
        mm_inputs["cross_attention_mask"] = torch.from_numpy(
            convert_sparse_cross_attention_mask_to_dense(
                cross_attention_token_mask,
                num_tiles=num_tiles,
                max_num_tiles=max_image_tiles,
                length=max(len(input_ids) for input_ids in batch_ids),
            )
        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
        return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
class PaliGemmaPlugin(BasePlugin):
|
class PaliGemmaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -868,92 +947,17 @@ class VideoLlavaPlugin(BasePlugin):
|
|||||||
return self._get_mm_inputs(images, videos, processor)
|
return self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
|
||||||
class MllamaPlugin(BasePlugin):
    r"""
    Multimodal plugin for mllama that accepts exactly one image per sample.
    """

    @override
    def process_messages(
        self,
        messages: Sequence[Dict[str, str]],
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: Optional["ProcessorMixin"],
    ) -> List[Dict[str, str]]:
        r"""
        Returns a deep copy of the messages with each image placeholder swapped for the image token.

        Raises:
            ValueError: if the placeholder count and the image count disagree.
        """
        self._validate_input(images, videos)
        converted = deepcopy(messages)
        placeholder_count = 0
        for turn in converted:
            text = turn["content"]
            placeholder_count += text.count(IMAGE_PLACEHOLDER)
            turn["content"] = text.replace(IMAGE_PLACEHOLDER, self.image_token)

        if placeholder_count != len(images):
            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

        return converted

    @override
    def _get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        processor: "ProcessorMixin",
    ) -> Dict[str, "torch.Tensor"]:
        r"""
        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].

        Every image is wrapped in its own single-element list before being processed.

        Returns:
            pixel_values: tensor with shape
                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
                          For example, (2, 1, 4, 3, 560, 560).
            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
        """
        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
        resolution = getattr(processor, "image_resolution", 512 * 512)
        regularized = self._regularize_images(images, image_resolution=resolution)
        single_image_batches = [[picture] for picture in regularized]
        return image_processor(single_image_batches, return_tensors="pt")

    def get_mm_inputs(
        self,
        images: Sequence["ImageInput"],
        videos: Sequence["VideoInput"],
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],
        processor: Optional["ProcessorMixin"],
    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
        r"""
        Builds the multimodal inputs of a batch and attaches the dense cross-attention mask.

        ``imglens`` and ``vidlens`` are accepted but not used by this implementation.

        Raises:
            ValueError: if the number of images differs from the number of samples in ``batch_ids``.
        """
        self._validate_input(images, videos)
        if len(batch_ids) != len(images):
            raise ValueError("Mllama only supports one image per sample.")

        mm_inputs = self._get_mm_inputs(images, videos, processor)
        # The tile counts only feed the mask construction and are not returned.
        tiles_per_image = mm_inputs.pop("num_tiles")
        image_token_id = getattr(processor, "image_token_id")
        tile_limit = getattr(processor.image_processor, "max_image_tiles")
        sparse_masks = [get_cross_attention_token_mask(token_ids, image_token_id) for token_ids in batch_ids]
        longest = max(map(len, batch_ids))
        dense_mask = convert_sparse_cross_attention_mask_to_dense(
            sparse_masks,
            num_tiles=tiles_per_image,
            max_num_tiles=tile_limit,
            length=longest,
        )
        # shape: (batch_size, length, max_num_images, max_num_tiles)
        mm_inputs["cross_attention_mask"] = torch.from_numpy(dense_mask)
        return mm_inputs
|
|
||||||
|
|
||||||
|
|
||||||
# Lookup table from a plugin name to the plugin class that implements it (names in alphabetical order).
PLUGINS = {
    "base": BasePlugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "minicpm_v": MiniCPMVPlugin,
    "mllama": MllamaPlugin,
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
    "qwen2_vl": Qwen2vlPlugin,
    "video_llava": VideoLlavaPlugin,
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -759,8 +759,9 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
|
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
@ -786,8 +787,9 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
|
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
@ -838,8 +840,9 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
format_function=FunctionFormatter(slots=["{{content}}", "<|eot_id|>"], tool_format="llama3"),
|
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
@ -1130,15 +1133,18 @@ _register_template(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
|
||||||
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
|
||||||
format_observation=StringFormatter(
|
format_observation=StringFormatter(
|
||||||
slots=[
|
slots=[
|
||||||
(
|
(
|
||||||
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
"<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
||||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="llama3"),
|
||||||
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
||||||
default_system=(
|
default_system=(
|
||||||
"You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
|
"You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
|
||||||
@ -1147,7 +1153,7 @@ _register_template(
|
|||||||
"After completing your thoughts, you then provide a detailed explanation of the solution process "
|
"After completing your thoughts, you then provide a detailed explanation of the solution process "
|
||||||
"in your response."
|
"in your response."
|
||||||
),
|
),
|
||||||
stop_words=["<|eot_id|>"],
|
stop_words=["<|eot_id|>", "<|eom_id|>"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user