From f00f4ae9b6c944220a66538ac54c7d2a18d6474d Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Tue, 10 Sep 2024 12:31:53 +0800
Subject: [PATCH 01/33] support llava-next(video)

Former-commit-id: 31259e7e0caa9ff6449b4abcee0554e211167178
---
 README.md                                    |  58 +++---
 README_zh.md                                 |   2 +
 requirements.txt                             |   1 +
 setup.py                                     |   1 +
 src/llamafactory/data/mm_plugin.py           | 178 ++++++++++++++++++++
 src/llamafactory/data/template.py            |  43 +++++
 src/llamafactory/extras/constants.py         |  71 ++++++++
 src/llamafactory/model/loader.py             |  12 +-
 src/llamafactory/model/model_utils/misc.py   |   2 +-
 src/llamafactory/model/model_utils/visual.py |   4 +-
 tests/data/test_mm_plugin.py                 |  55 +++++-
 11 files changed, 394 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index 8bc99730..963ca26a 100644
--- a/README.md
+++ b/README.md
@@ -160,34 +160,36 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Supported Models
 
-| Model                                                             | Model size                       | Template  |
-| ----------------------------------------------------------------- | -------------------------------- | --------- |
-| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
-| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
-| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
-| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
-| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
-| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
-| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
-| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
-| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
-| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
-| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
-| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
-| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
-| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
-| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
-| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
-| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
-| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
-| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
-| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
-| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
-| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
-| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
-| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
+| Model                                                             | Model size                       | Template         |
+|-------------------------------------------------------------------| -------------------------------- |------------------|
+| [Baichuan 2](https://huggingface.co/baichuan-inc) | 7B/13B | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience) | 560M/1.1B/1.7B/3B/7.1B/176B | - |
+| [ChatGLM3](https://huggingface.co/THUDM) | 6B | chatglm3 |
+| [Command R](https://huggingface.co/CohereForAI) | 35B/104B | cohere |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai) | 7B/16B/67B/236B | deepseek |
+| [Falcon](https://huggingface.co/tiiuae) | 7B/11B/40B/180B | falcon |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google) | 2B/7B/9B/27B | gemma |
+| [GLM-4](https://huggingface.co/THUDM) | 9B | glm4 |
+| [InternLM2/InternLM2.5](https://huggingface.co/internlm) | 7B/20B | intern2 |
+| [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - |
+| [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
+| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
+| [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/13B | llava_next |
+| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/13B | llava_next_video |
+| [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
+| [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
+| [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
+| [PaliGemma](https://huggingface.co/google) | 3B | paligemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft) | 1.3B/2.7B | - |
+| [Phi-3](https://huggingface.co/microsoft) | 4B/7B/14B | phi |
+| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen |
+| [Qwen2-VL](https://huggingface.co/Qwen) | 2B/7B | qwen2_vl |
+| [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - |
+| [XVERSE](https://huggingface.co/xverse) | 7B/13B/65B | xverse |
+| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai) | 1.5B/6B/9B/34B | yi |
+| [Yi-VL](https://huggingface.co/01-ai) | 6B/34B | yi_vl |
+| [Yuan 2](https://huggingface.co/IEITYuan) | 2B/51B/102B | yuan |
 
 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
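As a quick illustration of what the two new table entries mean in practice, here is a minimal usage sketch (not part of the patch) that exercises the `llava_next_video` plugin registered further down in this change. It assumes LLaMA-Factory is installed from this branch, that `get_mm_plugin` keeps the signature used in `src/llamafactory/data/template.py` below, and that `<image>`/`<video>` are the placeholder strings behind `IMAGE_PLACEHOLDER`/`VIDEO_PLACEHOLDER`:

```python
# Minimal sketch, not part of the patch: exercise the llava_next_video plugin
# added in src/llamafactory/data/mm_plugin.py below. Assumes this branch is
# installed and that "<image>"/"<video>" match IMAGE_PLACEHOLDER/VIDEO_PLACEHOLDER.
from llamafactory.data.mm_plugin import get_mm_plugin

plugin = get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>")

messages = [
    {"role": "user", "content": "<video>What happens in this clip?"},
    {"role": "assistant", "content": "A cat chases a laser pointer."},
]

# process_messages only counts and validates the placeholders at this stage, so the
# (hypothetical) video file is never opened; a ValueError is raised on a count mismatch.
processed = plugin.process_messages(messages, images=[], videos=["demo.mp4"], processor=None)
print(processed)
```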
diff --git a/README_zh.md b/README_zh.md
index e80a2104..251b1f87 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -176,6 +176,8 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272
 | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 |
 | [Llama 3/Llama 3.1](https://huggingface.co/meta-llama) | 8B/70B | llama3 |
 | [LLaVA-1.5](https://huggingface.co/llava-hf) | 7B/13B | llava |
+| [LLaVA-NeXT](https://huggingface.co/llava-hf) | 7B/13B | llava_next |
+| [LLaVA-NeXT-Video](https://huggingface.co/llava-hf) | 7B/13B | llava_next_video |
 | [MiniCPM](https://huggingface.co/openbmb) | 1B/2B/4B | cpm/cpm3 |
 | [Mistral/Mixtral](https://huggingface.co/mistralai) | 7B/8x7B/8x22B | mistral |
 | [OLMo](https://huggingface.co/allenai) | 1B/7B | - |
diff --git a/requirements.txt b/requirements.txt
index 54d58bb3..1c1b4c55 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ fire
 packaging
 pyyaml
 numpy<2.0.0
+av
\ No newline at end of file
diff --git a/setup.py b/setup.py
index a80cb81b..5e969e51 100644
--- a/setup.py
+++ b/setup.py
@@ -61,6 +61,7 @@ extra_require = {
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
     "dev": ["ruff", "pytest"],
+    "av": ["av>=13.0.0"],
 }
 
 
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index c109d26e..22c49468 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -209,6 +209,50 @@ class BasePlugin:
         return {}
 
 
+class Idefics2Plugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        fake_image_token = processor.fake_image_token.content
+        image_str = f"{fake_image_token}{self.image_token * processor.image_seq_len}{fake_image_token}"
+        image_str = image_str * 5  # assumes image splitting is enabled: 4 crops + the original image = 5 sub-images
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+            content = content.replace("{{image}}", image_str)
+            content = content.replace(f"{fake_image_token}{fake_image_token}", f"{fake_image_token}")
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return _get_mm_inputs(images, videos, processor)
+
+
 class LlavaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -249,6 +293,92 @@ class LlavaPlugin(BasePlugin):
         return _get_mm_inputs(images, videos, processor)
 
 
+class LlavaNextPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return _get_mm_inputs(images, videos, processor)
+
+
+class LlavaNextVideoPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+            while VIDEO_PLACEHOLDER in content:
+                num_video_tokens += 1
+                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        if len(videos) != num_video_tokens:
+            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        video_processor = getattr(processor, "video_processor")
+        res = _get_mm_inputs(images, [], processor)
+        if len(videos) != 0:
+            videos = _regularize_videos(videos, processor)
+            video_res = video_processor(videos, return_tensors="pt")
+            res.update(video_res)
+        return res
+
 class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
@@ -380,11 +510,59 @@ class Qwen2vlPlugin(BasePlugin):
         return _get_mm_inputs(images, videos, processor)
 
 
+class VideoLlavaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        num_video_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                num_image_tokens += 1
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+            while VIDEO_PLACEHOLDER in content:
+                num_video_tokens += 1
+                content = content.replace(VIDEO_PLACEHOLDER, "{{video}}", 1)
+
+        if len(images) != num_image_tokens:
+            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+
+        if len(videos) != num_video_tokens:
+            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        return _get_mm_inputs(images, videos, processor)
+
+
 PLUGINS = {
     "base": BasePlugin,
+    "idefics2": Idefics2Plugin,
     "llava": LlavaPlugin,
+    "llava_next": LlavaNextPlugin,
+    "llava_next_video": LlavaNextVideoPlugin,
     "paligemma": PaliGemmaPlugin,
     "qwen2_vl": Qwen2vlPlugin,
+    "video_llava": VideoLlavaPlugin,
 }
 
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index ff5e32d2..7bf164e6 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -680,6 +680,16 @@ _register_template(
 )
 
 
+_register_template(
+    name="idefics2",
+    format_user=StringFormatter(slots=["User:{{content}}\nAssistant:"]),
+    format_separator=EmptyFormatter(slots=["\n"]),
+    stop_words=["<end_of_utterance>"],
+    replace_eos=True,
+    mm_plugin=get_mm_plugin(name="idefics2", image_token="<image>"),
+)
+
+
 _register_template(
     name="intern",
     format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
@@ -753,6 +763,28 @@ _register_template(
 )
 
 
+_register_template(
+    name="llava_next",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
+)
+
+
+_register_template(
+    name="llava_next_video",
+    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
+    default_system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="