mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
add some
Former-commit-id: 79c2d7090cbf364063ea3608814ab18aa27fdc87
This commit is contained in:
parent
f318dc9464
commit
b5ef5059ee
@ -1,4 +1,4 @@
|
|||||||
transformers>=4.41.2,<=4.46.1
|
transformers>=4.41.2
|
||||||
datasets>=2.16.0,<=3.1.0
|
datasets>=2.16.0,<=3.1.0
|
||||||
accelerate>=0.34.0,<=1.0.1
|
accelerate>=0.34.0,<=1.0.1
|
||||||
peft>=0.11.1,<=0.12.0
|
peft>=0.11.1,<=0.12.0
|
||||||
|
@ -149,6 +149,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
features.update(mm_inputs)
|
features.update(mm_inputs)
|
||||||
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
|
if isinstance(features.get("pixel_values"), list): # for pixtral inputs
|
||||||
features = features.data # use default_collate() instead of BatchEncoding.to()
|
features = features.data # use default_collate() instead of BatchEncoding.to()
|
||||||
|
if "image_bound" in features:
|
||||||
|
input_ids, position_ids = features['input_ids'], features['position_ids']
|
||||||
|
features['position_ids'] = F.pad(position_ids, (0, input_ids.shape[-1] - position_ids.shape[-1]))
|
||||||
|
new_features = {}
|
||||||
|
new_features.update({"data": features})
|
||||||
|
new_features.update(features)
|
||||||
|
features = new_features
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ import math
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||||
|
import re
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -249,6 +250,130 @@ class BasePlugin:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class CpmOPlugin(BasePlugin):
|
||||||
|
@override
|
||||||
|
def process_messages(
|
||||||
|
self,
|
||||||
|
messages: Sequence[Dict[str, str]],
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
num_image_tokens = 0
|
||||||
|
messages = deepcopy(messages)
|
||||||
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
content = message["content"]
|
||||||
|
while IMAGE_PLACEHOLDER in content:
|
||||||
|
num_image_tokens += 1
|
||||||
|
content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
|
||||||
|
|
||||||
|
message["content"] = content.replace("{{image}}", "(<image>./</image>)")
|
||||||
|
|
||||||
|
if num_image_tokens>0:
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
|
||||||
|
pattern = "(<image>./</image>)"
|
||||||
|
images, image_sizes, tgt_sizes = mm_inputs["pixel_values"], mm_inputs["image_sizes"], mm_inputs["tgt_sizes"]
|
||||||
|
|
||||||
|
input_ids_list = []
|
||||||
|
image_bounds_list = []
|
||||||
|
image_index = 0
|
||||||
|
for index, message in enumerate(messages):
|
||||||
|
text = message['content']
|
||||||
|
image_tags = re.findall(pattern, text)
|
||||||
|
text_chunks = text.split(pattern)
|
||||||
|
final_text = ""
|
||||||
|
for i in range(len(image_tags)):
|
||||||
|
final_text = final_text + text_chunks[i] + \
|
||||||
|
image_processor.get_slice_image_placeholder(
|
||||||
|
image_sizes[image_index][i],
|
||||||
|
i,
|
||||||
|
image_processor.max_slice_nums,
|
||||||
|
image_processor.use_image_id,
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
final_text += text_chunks[-1]
|
||||||
|
messages[index]['content'] = final_text
|
||||||
|
# print(messages)
|
||||||
|
|
||||||
|
if len(images) != num_image_tokens:
|
||||||
|
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
processor: "ProcessorMixin",
|
||||||
|
) -> Dict[str, "torch.Tensor"]:
|
||||||
|
|
||||||
|
image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
|
||||||
|
|
||||||
|
mm_inputs = {}
|
||||||
|
|
||||||
|
if len(images) != 0:
|
||||||
|
images = self._regularize_images(
|
||||||
|
images,
|
||||||
|
image_resolution=getattr(processor, "image_resolution", 512 * 512),
|
||||||
|
)
|
||||||
|
image_inputs = image_processor(images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt")
|
||||||
|
mm_inputs.update(image_inputs)
|
||||||
|
|
||||||
|
if len(videos) != 0:
|
||||||
|
videos = self._regularize_videos(
|
||||||
|
videos,
|
||||||
|
image_resolution=getattr(processor, "video_resolution", 128 * 128),
|
||||||
|
video_fps=getattr(processor, "video_fps", 2.0),
|
||||||
|
video_maxlen=getattr(processor, "video_maxlen", 64),
|
||||||
|
)
|
||||||
|
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get_mm_inputs(
|
||||||
|
self,
|
||||||
|
images: Sequence["ImageInput"],
|
||||||
|
videos: Sequence["VideoInput"],
|
||||||
|
imglens: Sequence[int],
|
||||||
|
vidlens: Sequence[int],
|
||||||
|
batch_ids: Sequence[List[int]],
|
||||||
|
processor: Optional["ProcessorMixin"],
|
||||||
|
) -> Dict[str, Union[List[int], "torch.Tensor"]]:
|
||||||
|
self._validate_input(images, videos)
|
||||||
|
mm_inputs = self._get_mm_inputs(images, videos, processor)
|
||||||
|
image_bounds_list = []
|
||||||
|
position_ids = []
|
||||||
|
for input_ids in batch_ids:
|
||||||
|
input_ids_ = torch.tensor(input_ids)
|
||||||
|
start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (input_ids_ == processor.tokenizer.slice_start_id)
|
||||||
|
end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
|
||||||
|
image_start_tokens = torch.where(start_cond)[0]
|
||||||
|
image_start_tokens += 1
|
||||||
|
image_end_tokens = torch.where(end_cond)[0]
|
||||||
|
valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
|
||||||
|
image_bounds = torch.hstack(
|
||||||
|
[
|
||||||
|
image_start_tokens[:valid_image_nums].unsqueeze(-1),
|
||||||
|
image_end_tokens[:valid_image_nums].unsqueeze(-1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
image_bounds_list.append(image_bounds)
|
||||||
|
position_ids_ = list(range(input_ids_.size(0)))
|
||||||
|
# print(input_ids_.shape, len(position_ids_)
|
||||||
|
position_ids.append(position_ids_)
|
||||||
|
position_ids = torch.tensor(position_ids, dtype=torch.int64)
|
||||||
|
mm_inputs.update({
|
||||||
|
"image_bound": image_bounds_list,
|
||||||
|
"position_ids": position_ids,
|
||||||
|
})
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
|
||||||
class LlavaPlugin(BasePlugin):
|
class LlavaPlugin(BasePlugin):
|
||||||
@override
|
@override
|
||||||
def process_messages(
|
def process_messages(
|
||||||
@ -790,6 +915,7 @@ class MllamaPlugin(BasePlugin):
|
|||||||
|
|
||||||
PLUGINS = {
|
PLUGINS = {
|
||||||
"base": BasePlugin,
|
"base": BasePlugin,
|
||||||
|
"cpm_o": CpmOPlugin,
|
||||||
"llava": LlavaPlugin,
|
"llava": LlavaPlugin,
|
||||||
"llava_next": LlavaNextPlugin,
|
"llava_next": LlavaNextPlugin,
|
||||||
"llava_next_video": LlavaNextVideoPlugin,
|
"llava_next_video": LlavaNextVideoPlugin,
|
||||||
|
@ -583,6 +583,22 @@ _register_template(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_register_template(
|
||||||
|
name="cpm_o",
|
||||||
|
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
||||||
|
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
||||||
|
format_function=FunctionFormatter(slots=["{{content}}", "<|im_end|>"], tool_format="qwen"),
|
||||||
|
format_observation=StringFormatter(
|
||||||
|
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
|
||||||
|
),
|
||||||
|
format_tools=ToolFormatter(tool_format="qwen"),
|
||||||
|
format_separator=EmptyFormatter(slots=["\n"]),
|
||||||
|
default_system="You are a helpful assistant.",
|
||||||
|
stop_words=["<|im_end|>"],
|
||||||
|
mm_plugin=get_mm_plugin(name="cpm_o", image_token="<image>", video_token="<video>"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# copied from chatml template
|
# copied from chatml template
|
||||||
_register_template(
|
_register_template(
|
||||||
name="dbrx",
|
name="dbrx",
|
||||||
|
@ -1141,6 +1141,17 @@ register_model_group(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
register_model_group(
|
||||||
|
models={
|
||||||
|
"MiniCPM-V-2_6-Chat": {
|
||||||
|
DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
|
||||||
|
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
template="cpm_o",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
register_model_group(
|
register_model_group(
|
||||||
models={
|
models={
|
||||||
"Mistral-7B-v0.1": {
|
"Mistral-7B-v0.1": {
|
||||||
|
@ -81,7 +81,7 @@ def check_dependencies() -> None:
|
|||||||
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||||
return
|
return
|
||||||
|
|
||||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
|
||||||
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
|
require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
|
||||||
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
||||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||||
|
@ -44,6 +44,8 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
|
|||||||
forbidden_modules.add("vision_model")
|
forbidden_modules.add("vision_model")
|
||||||
elif model_type == "qwen2_vl":
|
elif model_type == "qwen2_vl":
|
||||||
forbidden_modules.add("visual")
|
forbidden_modules.add("visual")
|
||||||
|
elif model_type in ["minicpmv"]:
|
||||||
|
forbidden_modules.add("vpm")
|
||||||
else:
|
else:
|
||||||
forbidden_modules.add("vision_tower")
|
forbidden_modules.add("vision_tower")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user