Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)

Commit 0fb50f9c88 ("add some")
Parent: bcbe37ff52
Former-commit-id: 771cc802941cf1953b32e5102c817c6a3090b5ce
@@ -19,8 +19,8 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

 import torch
-from torch.nn.utils.rnn import pad_sequence
 import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
 from transformers import DataCollatorForSeq2Seq

 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
@@ -106,7 +106,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_vidlens.append(len(videos))
             batch_input_ids.append(feature["input_ids"])

-        if self.processor is not None and sum(batch_imglens) == 0:  # avoid process hanging in zero3/fsdp case
+        if (
+            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+        ):  # avoid process hanging in zero3/fsdp case
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
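Note: the condition now also requires sum(batch_vidlens) == 0, so a batch that carries videos but no images no longer receives a fake image. A minimal sketch of why the dummy input exists at all, assuming standard ZeRO-3/FSDP semantics (illustrative code, not from the repo):

# Illustrative sketch, not repo code: under DeepSpeed ZeRO-3 or FSDP, sharded
# weights are all-gathered whenever a module's forward runs. If one rank skips
# the vision tower (text-only batch) while another rank enters it, the
# collective never completes and training hangs; hence the 64x64 white dummy.
from PIL import Image

def ensure_visual_input(images):
    """Return the batch's images, adding one dummy frame when the list is empty."""
    if len(images) == 0:
        images = [Image.new("RGB", (64, 64), (255, 255, 255))]
    return images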
@@ -157,10 +159,14 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):

         if "image_bound" in features:  # for minicpmv inputs
             features["position_ids"] = [torch.arange(input_ids.size(0)).long() for input_ids in features["input_ids"]]
-            features["input_ids"] = pad(features["input_ids"],)
+            features["input_ids"] = pad(
+                features["input_ids"],
+            )
             features["position_ids"] = pad(features["position_ids"])
             features["labels"] = pad(features["labels"], padding_value=-100)
-            features["attention_mask"] = pad(features["attention_mask"],)
+            features["attention_mask"] = pad(
+                features["attention_mask"],
+            )
             new_features = {}
             new_features.update({"data": features})
             new_features.update(features)
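The pad helper called here is defined outside this hunk; a minimal sketch of its likely shape, assuming it wraps torch.nn.utils.rnn.pad_sequence (the name, signature, and default are assumptions):

import torch
from torch.nn.utils.rnn import pad_sequence

def pad(sequences, padding_value=0):
    # Right-pad a list of 1-D tensors to the longest length in the batch.
    return pad_sequence(sequences, batch_first=True, padding_value=padding_value)

batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(pad(batch))                      # tensor([[1, 2, 3], [4, 5, 0]])
print(pad(batch, padding_value=-100))  # labels: padded positions are ignored by the loss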
@@ -265,8 +265,19 @@ class CpmOPlugin(BasePlugin):
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
+        num_video_tokens = 0
         messages = deepcopy(messages)
         image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+
+        if len(videos) != 0:
+            assert len(images) == 0, "Only support video and image sft seperately"
+            max_slice_nums = 2
+            use_image_id = False
+            mm_inputs = self._get_mm_inputs([], videos, processor)
+        else:
+            max_slice_nums = image_processor.max_slice_nums
+            use_image_id = image_processor.use_image_id

         for message in messages:
             content = message["content"]
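The new branch keys the slicing configuration off the modality: videos force max_slice_nums=2 and disable image ids, while still images keep the processor's own settings. A standalone sketch of the same decision (hypothetical helper, for illustration only):

def pick_slice_config(num_images, num_videos, processor_max_slice_nums, processor_use_image_id):
    # Mixing images and videos in one sample is unsupported, mirroring the assert above.
    if num_videos != 0:
        assert num_images == 0, "images and videos cannot be mixed in one sample"
        return 2, False  # cheaper slicing and no image ids for video frames
    return processor_max_slice_nums, processor_use_image_id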
@@ -274,15 +285,21 @@ class CpmOPlugin(BasePlugin):
                 num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)

+            while VIDEO_PLACEHOLDER in content:
+                num_video_tokens += 1
+                content = content.replace(
+                    VIDEO_PLACEHOLDER, "{{image}}" * len(mm_inputs["pixel_values"][num_video_tokens - 1]), 1
+                )
+
             message["content"] = content.replace("{{image}}", "(<image>./</image>)")

         if num_image_tokens > 0:
-            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            mm_inputs = self._get_mm_inputs(images, [], processor)
+
+        if mm_inputs:
             pattern = "(<image>./</image>)"
-            images, image_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"]
+            image_sizes = mm_inputs["image_sizes"]

-            image_index = 0
             for index, message in enumerate(messages):
                 text = message["content"]
                 image_tags = re.findall(pattern, text)
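Each VIDEO_PLACEHOLDER is expanded into one "{{image}}" slot per processed frame, so a video reaches the template as a run of image tokens. A simplified, self-contained sketch (the placeholder string and frame counts are assumed for illustration):

VIDEO_PLACEHOLDER = "<video>"  # assumed value, for illustration only

def expand_video_placeholders(content, frames_per_video):
    # Replace each video tag with as many image slots as that video has frames.
    video_index = 0
    while VIDEO_PLACEHOLDER in content:
        content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * frames_per_video[video_index], 1)
        video_index += 1
    return content

print(expand_video_placeholders("describe <video>", [3]))
# -> "describe {{image}}{{image}}{{image}}"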
@@ -293,19 +310,21 @@ class CpmOPlugin(BasePlugin):
                         final_text
                         + text_chunks[i]
                         + image_processor.get_slice_image_placeholder(
-                            image_sizes[image_index][i],
+                            image_sizes[0][i],
                             i,
-                            image_processor.max_slice_nums,
-                            image_processor.use_image_id,
+                            max_slice_nums,
+                            use_image_id,
                         )
                     )
-                image_index += 1
                 final_text += text_chunks[-1]
                 messages[index]["content"] = final_text

         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
         return messages

     @override
@@ -333,7 +352,7 @@ class CpmOPlugin(BasePlugin):
                 new_images.append(images[idx : idx + valid_image_nums])
                 idx += valid_image_nums
             images = new_images

         image_inputs = image_processor(
             images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
         )
@@ -346,6 +365,8 @@ class CpmOPlugin(BasePlugin):
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 64),
             )
+            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+            mm_inputs.update(video_inputs)

         return mm_inputs

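Here video_fps and video_maxlen bound how many frames are sampled from a clip, which in turn sets how many "{{image}}" slots each video expands into earlier. A rough sketch of that arithmetic (illustrative only; the actual frame sampler is not shown in this hunk):

def num_sampled_frames(duration_seconds, video_fps=2.0, video_maxlen=64):
    # Sample at video_fps frames per second, but never more than video_maxlen frames.
    return min(int(duration_seconds * video_fps), video_maxlen)

print(num_sampled_frames(10.0))   # 20 frames at 2.0 fps
print(num_sampled_frames(120.0))  # capped at 64 frames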
@@ -380,12 +401,9 @@ class CpmOPlugin(BasePlugin):
             ]
         )
         image_bounds_list.append(image_bounds)

         mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
-        mm_inputs.update(
-            {
-                "image_bound": image_bounds_list,
-            }
-        )
+        mm_inputs.update({"image_bound": image_bounds_list})
         return mm_inputs