diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 40414399..fbcfd46a 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -4,6 +4,7 @@ from io import BytesIO from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union import numpy as np +from transformers.image_utils import get_image_size, to_numpy_array from typing_extensions import override from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER @@ -173,7 +174,6 @@ class BasePlugin: video_maxlen=getattr(processor, "video_maxlen", 64), ) input_dict["videos"] = videos - if input_dict.get("images", None) is not None or input_dict.get("videos", None) is not None: return image_processor(**input_dict, return_tensors="pt") else: @@ -223,50 +223,6 @@ class BasePlugin: return {} -class Idefics2Plugin(BasePlugin): - @override - def process_messages( - self, - messages: Sequence[Dict[str, str]], - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - processor: Optional["ProcessorMixin"], - ) -> List[Dict[str, str]]: - self._validate_input(images, videos) - num_image_tokens = 0 - messages = deepcopy(messages) - fake_image_token = processor.fake_image_token.content - image_str = f"{fake_image_token}{self.image_token * processor.image_seq_len}{fake_image_token}" - image_str = image_str * 5 - - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - num_image_tokens += 1 - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) - content = content.replace("{{image}}", image_str) - content = content.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}") - message["content"] = content - - if len(images) != num_image_tokens: - raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)) - - return messages - - @override - def get_mm_inputs( - self, - images: Sequence["ImageInput"], - videos: Sequence["VideoInput"], - imglens: Sequence[int], - vidlens: Sequence[int], - seqlens: Sequence[int], - processor: Optional["ProcessorMixin"], - ) -> Dict[str, Union[List[int], "torch.Tensor"]]: - self._validate_input(images, videos) - return self._get_mm_inputs(images, videos, processor) - - class LlavaPlugin(BasePlugin): @override def process_messages( @@ -319,15 +275,33 @@ class LlavaNextPlugin(BasePlugin): self._validate_input(images, videos) num_image_tokens = 0 messages = deepcopy(messages) - for message in messages: - content = message["content"] - while IMAGE_PLACEHOLDER in content: - num_image_tokens += 1 - content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1) + if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None: + for message in messages: + content = message["content"] + while self.image_token in content: + num_image_tokens += 1 + content = content.replace(self.image_token, "{{image}}", 1) + else: + mm_inputs = self._get_mm_inputs(images, videos, processor) + image_sizes = iter(mm_inputs["image_sizes"]) + height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0])) + for message in messages: + content = message["content"] + while self.image_token in content: + image_size = next(image_sizes) + orig_height, orig_width = image_size + image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width) + if processor.vision_feature_select_strategy == "default": + image_seqlen -= 1 + num_image_tokens += 1 + 
+                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+                message["content"] = content.replace("{{image}}", self.image_token)
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
 
         return messages
 
     @override
@@ -357,14 +331,47 @@ class LlavaNextVideoPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-            while VIDEO_PLACEHOLDER in content:
-                num_video_tokens += 1
-                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
+        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+        else:
+            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            if "pixel_values" in mm_inputs:
+                image_sizes = iter(mm_inputs["image_sizes"])
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
+                for message in messages:
+                    content = message["content"]
+
+                    while self.image_token in content:
+                        image_size = next(image_sizes)
+                        orig_height, orig_width = image_size
+                        image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                        if processor.vision_feature_select_strategy == "default":
+                            image_seqlen -= 1
+                        num_image_tokens += 1
+                        content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+
+                    message["content"] = content.replace("{{image}}", self.image_token)
+
+            if "pixel_values_videos" in mm_inputs:
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
+                video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+
+                for message in messages:
+                    content = message["content"]
+                    while self.video_token in content:
+                        num_video_tokens += 1
+                        content = content.replace(self.video_token, "{{video}}", 1)
+                    message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
@@ -393,6 +400,19 @@ class LlavaNextVideoPlugin(BasePlugin):
         res.update(video_res)
         return res
+
+    @override
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        r"""
+        Regularizes videos to avoid errors, including reading, resizing and converting.
+        """
+        videos = super()._regularize_videos(
+            videos,
+            image_resolution=128,
+            video_fps=1.0,
+            video_maxlen=64,
+        )
+        return videos
 
 
 class PaliGemmaPlugin(BasePlugin):
     @override
@@ -561,14 +581,42 @@ class VideoLlavaPlugin(BasePlugin):
         num_image_tokens = 0
         num_video_tokens = 0
         messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
-            while VIDEO_PLACEHOLDER in content:
-                num_video_tokens += 1
-                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
+        if getattr(processor, "patch_size") is None or getattr(processor, "vision_feature_select_strategy") is None:
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+        else:
+            mm_inputs = self._get_mm_inputs(images, videos, processor)
+            if "pixel_values_images" in mm_inputs.keys():
+                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                num_frames = 1
+
+            if "pixel_values_videos" in mm_inputs.keys():
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+
+            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+            video_seqlen = image_seqlen * num_frames
+            if processor.vision_feature_select_strategy == "default":
+                image_seqlen -= 1
+
+            for message in messages:
+                content = message["content"]
+                while self.image_token in content:
+                    num_image_tokens += 1
+                    content = content.replace(self.image_token, "{{image}}", 1)
+                while self.video_token in content:
+                    num_video_tokens += 1
+                    content = content.replace(self.video_token, "{{video}}", 1)
+
+                content = content.replace("{{image}}", self.image_token * image_seqlen)
+                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
 
         if len(images) != num_image_tokens:
             raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
@@ -591,10 +639,22 @@ class VideoLlavaPlugin(BasePlugin):
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)
+
+    @override
+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        r"""
+        Regularizes videos to avoid errors, including reading, resizing and converting.
+        """
+        videos = super()._regularize_videos(
+            videos,
+            image_resolution=128,
+            video_fps=1.0,
+            video_maxlen=64,
+        )
+        return videos
 
 
 PLUGINS = {
     "base": BasePlugin,
-    "idefics2": Idefics2Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 99fca395..2d966155 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -686,16 +686,6 @@ _register_template(
 )
 
 
-_register_template(
-    name="idefics2",
-    format_user=StringFormatter(slots=["User:{{content}}\nAssistant:"]),
-    format_separator=EmptyFormatter(slots=["\n"]),
-    stop_words=["<end_of_utterance>"],
-    replace_eos=True,
-    mm_plugin=get_mm_plugin(name="idefics2", image_token="<image>"),
-)
-
-
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index a3667249..335d222a 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -583,23 +583,6 @@ register_model_group(
 )
 
 
-register_model_group(
-    models={
-        "Idefics2-Base": {
-            DownloadSource.DEFAULT: "HuggingFaceM4/idefics2-8b-base",
-        },
-        "Idefics2-Chat": {
-            DownloadSource.DEFAULT: "HuggingFaceM4/idefics2-8b",
-        },
-        "Idefics2-Chatty": {
-            DownloadSource.DEFAULT: "HuggingFaceM4/idefics2-8b-chatty",
-        },
-    },
-    template="idefics2",
-    vision=True,
-)
-
-
 register_model_group(
     models={
         "InternLM-7B": {
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 502af2a2..c2fdb2dd 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -119,15 +119,6 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     r"""
     Loads model config.
""" init_kwargs = _get_init_kwargs(model_args) - if "LLaVA-NeXT-Video" in model_args.model_name_or_path: - from transformers import CLIPVisionConfig, LlamaConfig, LlavaNextVideoConfig, PretrainedConfig - - official_config = PretrainedConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) - config = LlavaNextVideoConfig( - CLIPVisionConfig(**official_config.vision_config), LlamaConfig(**official_config.text_config) - ) - setattr(config, "visual_inputs", True) - return config return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) @@ -164,11 +155,6 @@ def load_model( load_class = AutoModelForVision2Seq else: load_class = AutoModelForCausalLM - if "llava_next_video" == getattr(config, "model_type"): - from transformers import LlavaNextVideoForConditionalGeneration - - load_class = LlavaNextVideoForConditionalGeneration - if model_args.train_from_scratch: model = load_class.from_config(config) else: diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 32662110..b5b581bb 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -92,7 +92,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen if getattr(model, "quantization_method", None): model_type = getattr(model.config, "model_type", None) - if model_type in ["llava", "paligemma"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector") elif model_type == "qwen2_vl": mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger") @@ -111,9 +111,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None: if model_type in [ "llava", "llava_next", - "video_llava", - "idefics2", "llava_next_video", + "video_llava", ]: # required for ds zero3 and valuehead models setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) @@ -128,7 +127,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni """ model_type = getattr(config, "model_type", None) forbidden_modules = set() - if model_type in ["llava", "paligemma"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: if finetuning_args.freeze_vision_tower: forbidden_modules.add("vision_tower") @@ -170,7 +169,7 @@ def patch_target_modules( """ model_type = getattr(config, "model_type", None) if finetuning_args.freeze_vision_tower: - if model_type in ["llava", "paligemma"]: + if model_type in ["llava", "llava_next", "llava_next_video", "paligemma", "video_llava"]: return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) elif model_type == "qwen2_vl": return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))