Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-09-12 16:12:48 +08:00)
add llava-next/llava-next-video/video-llava
Former-commit-id: 6642cd501d55a1657678428ef2aa0c9b99b7e83f
This commit is contained in:
parent c576b7ca32
commit 5aa1e847d9
@@ -4,6 +4,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union

 import numpy as np
+from transformers.image_utils import get_image_size, to_numpy_array
 from typing_extensions import override

 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -173,7 +174,6 @@ class BasePlugin:
                 video_maxlen=getattr(processor, "video_maxlen", 64),
             )
             input_dict["videos"] = videos

         if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None:
             return image_processor(**input_dict, return_tensors="pt")
         else:
@@ -223,50 +223,6 @@ class BasePlugin:
         return {}


-class Idefics2Plugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens = 0
-        messages = deepcopy(messages)
-        fake_image_token = processor.fake_image_token.content
-        image_str = f"{fake_image_token}{self.image_token * processor.image_seq_len}{fake_image_token}"
-        image_str = image_str * 5
-
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-            content = content.replace("{{image}}", image_str)
-            content = content.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
-            message["content"] = content
-
-        if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
-
-        return messages
-
-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        seqlens: Sequence[int],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
-
-
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -319,15 +275,33 @@ class LlavaNextPlugin(BasePlugin):
         self._validate_input(images, videos)
         num_image_tokens = 0
         messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+        else:
+            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            image_sizes = iter(mm_inputs["image_sizes"])
+            height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    image_size = next(image_sizes)
+                    orig_height, orig_width = image_size
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if processor.vision_feature_select_strategy == "default":
+                        image_seqlen -= 1
+                    num_image_tokens += 1
+                    print(image_seqlen)
+                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+                message['content'] = content.replace("{{image}}", self.image_token)

         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+        print(messages)
         return messages

     @override
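Note on the expansion above: each image placeholder is replaced by as many tokens as the vision encoder will emit for that image, and the plugin gets that count from `processor._get_number_of_features`. A minimal sketch of the simpler single-grid case, assuming a hypothetical square patch grid and ignoring LLaVA-NeXT's multi-crop tiling (numbers below are illustrative only, not the library's exact computation):

# Rough sketch: vision feature count for a single resized image, assuming a
# square patch grid. The "default" feature-select strategy drops the CLS token,
# mirroring the `image_seqlen -= 1` in the diff above. LLaVA-NeXT additionally
# tiles the image into crops, so processor._get_number_of_features is larger.
def naive_image_seqlen(height: int, width: int, patch_size: int = 14, drop_cls: bool = True) -> int:
    seqlen = (height // patch_size) * (width // patch_size)
    if not drop_cls:
        seqlen += 1
    return seqlen

# e.g. a 336x336 input with 14x14 patches -> 24 * 24 = 576 placeholder tokens
assert naive_image_seqlen(336, 336) == 576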
@@ -341,8 +315,8 @@ class LlavaNextPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        return self._get_mm_inputs(images, videos, processor)
+        res = self._get_mm_inputs(images, videos, processor)
+        return res


 class LlavaNextVideoPlugin(BasePlugin):
     @override
@@ -357,14 +331,47 @@ class LlavaNextVideoPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-            while VIDEO_PLACEHOLDER in content:
-                num_video_tokens += 1
-                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
+        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+        else:
+            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            if "pixel_values" in mm_inputs:
+                image_sizes = iter(mm_inputs["image_sizes"])
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+                for message in messages:
+                    content = message["content"]
+
+                    while self.image_token in content:
+                        image_size = next(image_sizes)
+                        orig_height, orig_width = image_size
+                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                        if processor.vision_feature_select_strategy == "default":
+                            image_seqlen -= 1
+                        num_image_tokens += 1
+                        content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+                    message['content'] = content.replace("{{image}}", self.image_token)
+
+            if "pixel_values_videos" in mm_inputs:
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+
+                for message in messages:
+                    content = message["content"]
+                    while self.video_token in content:
+                        num_video_tokens += 1
+                        content = content.replace(self.video_token, "{{video}}", 1)
+                    message['content'] = content.replace("{{video}}", self.video_token * video_seqlen)

         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
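For the video branch above, the placeholder length per video is the per-frame patch grid divided by four (LLaVA-NeXT-Video average-pools frame features 2x2) times the number of sampled frames. A worked example of that arithmetic with assumed values (336x336 frames, patch size 14, 8 frames; the real values come from the `pixel_values_videos` tensor inspected in the diff):

# Hypothetical numbers for illustration only; the actual values are read from
# the processor output in the plugin code above.
patch_size = 14
height = width = 336          # per-frame resolution after preprocessing
num_frames = 8                # frames kept after sampling

image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576 per frame
video_seqlen = image_seqlen // 4 * num_frames                  # 2x2 avg pooling -> 144 * 8 = 1152

print(image_seqlen, video_seqlen)  # 576 1152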
@@ -393,6 +400,19 @@ class LlavaNextVideoPlugin(BasePlugin):
             res.update(video_res)
         return res

+    @override
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        r"""
+        Regularizes videos to avoid error. Including reading, resizing and converting.
+        """
+        videos = super()._regularize_videos(
+            videos,
+            image_resolution=128,
+            video_fps=1.0,
+            video_maxlen=64,
+        )
+        return videos
+

 class PaliGemmaPlugin(BasePlugin):
     @override
@@ -561,14 +581,42 @@ class VideoLlavaPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-            while VIDEO_PLACEHOLDER in content:
-                num_video_tokens += 1
-                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
+        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+        else:
+            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            if "pixel_values_images" in mm_inputs.keys():
+                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                num_frames = 1
+
+            if "pixel_values_videos" in mm_inputs.keys():
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+            video_seqlen = num_image_tokens * num_frames
+            if processor.vision_feature_select_strategy == "default":
+                image_seqlen -= 1
+
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+
+                message["content"] = content.replace("{{image}}", self.image_token * image_seqlen)
+                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)

         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
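The message rewriting in these plugins follows a two-pass pattern: placeholders are first swapped for temporary `{{image}}`/`{{video}}` markers so already-inserted tokens are not matched again, and the markers are then expanded to the final run of special tokens. A self-contained toy version of that pattern, with made-up token strings and sequence lengths:

# Toy illustration of the two-pass replace used by the plugins; the token
# strings and lengths here are invented for the example.
IMAGE_TOKEN, VIDEO_TOKEN = "<image>", "<video>"
image_seqlen, video_seqlen = 3, 4

content = f"compare {IMAGE_TOKEN} with this clip {VIDEO_TOKEN}"

# pass 1: protect each placeholder behind a marker
content = content.replace(IMAGE_TOKEN, "{{image}}").replace(VIDEO_TOKEN, "{{video}}")
# pass 2: expand every marker to the full run of special tokens
content = content.replace("{{image}}", IMAGE_TOKEN * image_seqlen)
content = content.replace("{{video}}", VIDEO_TOKEN * video_seqlen)

print(content)
# compare <image><image><image> with this clip <video><video><video><video>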
@@ -591,10 +639,22 @@ class VideoLlavaPlugin(BasePlugin):
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)

+    @override
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        r"""
+        Regularizes videos to avoid error. Including reading, resizing and converting.
+        """
+        videos = super()._regularize_videos(
+            videos,
+            image_resolution=128,
+            video_fps=1.0,
+            video_maxlen=64,
+        )
+        return videos
+

 PLUGINS = {
     "base": BasePlugin,
-    "idefics2": Idefics2Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
@@ -686,16 +686,6 @@ _register_template(
 )


-_register_template(
-    name="idefics2",
-    format_user=StringFormatter(slots=["User:{{content}}<end_of_utterance>\nAssistant:"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
-    stop_words=["<end_of_utterance>"],
-    replace_eos=True,
-    mm_plugin=get_mm_plugin(name="idefics2", image_token="<image>"),
-)
-
-
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@@ -583,23 +583,6 @@ register_model_group(
 )


-register_model_group(
-    models={
-        "Idefics2-Base": {
-            DownloadSource.DEFAULT: "HuggingFaceM4/idefics2-8b-base",
-        },
-        "Idefics2-Chat": {
-            DownloadSource.DEFAULT: "HuggingFaceM4/idefics2-8b",
-        },
-        "Idefics2-Chatty": {
-            DownloadSource.DEFAULT: "HuggingFaceM4/idefics2-8b-chatty",
-        },
-    },
-    template="idefics2",
-    vision=True,
-)
-
-
 register_model_group(
     models={
         "InternLM-7B": {
@@ -119,15 +119,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     Loads model config.
     """
     init_kwargs = _get_init_kwargs(model_args)
-    if "LLaVA-NeXT-Video" in model_args.model_name_or_path:
-        from transformers import CLIPVisionConfig, LlamaConfig, LlavaNextVideoConfig, PretrainedConfig
-
-        official_config = PretrainedConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        config = LlavaNextVideoConfig(
-            CLIPVisionConfig(**official_config.vision_config), LlamaConfig(**official_config.text_config)
-        )
-        setattr(config, "visual_inputs", True)
-        return config
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)

@@ -164,11 +155,6 @@ def load_model(
         load_class = AutoModelForVision2Seq
     else:
         load_class = AutoModelForCausalLM
-    if "llava_next_video" == getattr(config, "model_type"):
-        from transformers import LlavaNextVideoForConditionalGeneration
-
-        load_class = LlavaNextVideoForConditionalGeneration
-
     if model_args.train_from_scratch:
         model = load_class.from_config(config)
     else:
@@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen

     if getattr(model, "quantization_method", None):
         model_type = getattr(model.config, "model_type", None)
-        if model_type in ["llava", "paligemma"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
             mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
         elif model_type == "qwen2_vl":
             mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
@@ -111,9 +111,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     if model_type in [
         "llava",
         "llava_next",
-        "video_llava",
-        "idefics2",
         "llava_next_video",
+        "video_llava",
     ]:  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

@@ -128,7 +127,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
     """
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
-    if model_type in ["llava", "paligemma"]:
+    if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
         if finetuning_args.freeze_vision_tower:
             forbidden_modules.add("vision_tower")

@@ -170,7 +169,7 @@ def patch_target_modules(
     """
     model_type = getattr(config, "model_type", None)
    if finetuning_args.freeze_vision_tower:
-        if model_type in ["llava", "paligemma"]:
+        if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]:
             return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
         elif model_type == "qwen2_vl":
             return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
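The pattern returned above uses a negative lookahead so that LoRA target modules inside the frozen vision tower are excluded from matching. A quick check of how such a pattern behaves, using illustrative module names rather than ones taken from a real checkpoint:

import re

# Example only: the module names below are made up to show the matching behavior.
pattern = "^(?!.*vision_tower).*(?:{}).*".format("|".join(["q_proj", "v_proj"]))

print(bool(re.match(pattern, "language_model.layers.0.self_attn.q_proj")))        # True
print(bool(re.match(pattern, "vision_tower.encoder.layers.0.self_attn.q_proj")))  # False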