From 16c7326bc5023bc8865283a61df3f0831c9dfb59 Mon Sep 17 00:00:00 2001
From: BUAADreamer <1428195643@qq.com>
Date: Tue, 10 Sep 2024 13:12:51 +0800
Subject: [PATCH] try to pass test

Former-commit-id: 7b4ba0efb658422fd29dca63bac1e9cee8e82af8
---
 README.md                                    | 56 ++++++++++----------
 src/llamafactory/data/mm_plugin.py           | 25 ++++-----
 src/llamafactory/model/loader.py             |  8 ++-
 src/llamafactory/model/model_utils/visual.py |  8 ++-
 4 files changed, 54 insertions(+), 43 deletions(-)

diff --git a/README.md b/README.md
index 963ca26a..51fbc7c1 100644
--- a/README.md
+++ b/README.md
@@ -160,36 +160,36 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/
 
 ## Supported Models
 
-| Model                                                              | Model size                       | Template         |
-|-------------------------------------------------------------------| -------------------------------- |------------------|
-| [Baichuan 2](https://huggingface.co/baichuan-inc)                  | 7B/13B                           | baichuan2        |
-| [BLOOM/BLOOMZ](https://huggingface.co/bigscience)                  | 560M/1.1B/1.7B/3B/7.1B/176B      | -                |
-| [ChatGLM3](https://huggingface.co/THUDM)                           | 6B                               | chatglm3         |
-| [Command R](https://huggingface.co/CohereForAI)                    | 35B/104B                         | cohere           |
-| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)          | 7B/16B/67B/236B                  | deepseek         |
-| [Falcon](https://huggingface.co/tiiuae)                            | 7B/11B/40B/180B                  | falcon           |
-| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)           | 2B/7B/9B/27B                     | gemma            |
-| [GLM-4](https://huggingface.co/THUDM)                              | 9B                               | glm4             |
-| [InternLM2/InternLM2.5](https://huggingface.co/internlm)           | 7B/20B                           | intern2          |
-| [Llama](https://github.com/facebookresearch/llama)                 | 7B/13B/33B/65B                   | -                |
-| [Llama 2](https://huggingface.co/meta-llama)                       | 7B/13B/70B                       | llama2           |
-| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)             | 8B/70B                           | llama3           |
-| [LLaVA-1.5](https://huggingface.co/llava-hf)                       | 7B/13B                           | llava            |
+| Model                                                             | Model size                       | Template  |
+| ----------------------------------------------------------------- | -------------------------------- | --------- |
+| [Baichuan 2](https://huggingface.co/baichuan-inc)                 | 7B/13B                           | baichuan2 |
+| [BLOOM/BLOOMZ](https://huggingface.co/bigscience)                 | 560M/1.1B/1.7B/3B/7.1B/176B      | -         |
+| [ChatGLM3](https://huggingface.co/THUDM)                          | 6B                               | chatglm3  |
+| [Command R](https://huggingface.co/CohereForAI)                   | 35B/104B                         | cohere    |
+| [DeepSeek (Code/MoE)](https://huggingface.co/deepseek-ai)         | 7B/16B/67B/236B                  | deepseek  |
+| [Falcon](https://huggingface.co/tiiuae)                           | 7B/11B/40B/180B                  | falcon    |
+| [Gemma/Gemma 2/CodeGemma](https://huggingface.co/google)          | 2B/7B/9B/27B                     | gemma     |
+| [GLM-4](https://huggingface.co/THUDM)                             | 9B                               | glm4      |
+| [InternLM2/InternLM2.5](https://huggingface.co/internlm)          | 7B/20B                           | intern2   |
+| [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -         |
+| [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2    |
+| [Llama 3/Llama 3.1](https://huggingface.co/meta-llama)            | 8B/70B                           | llama3    |
+| [LLaVA-1.5](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava     |
 | [LLaVA-NeXT](https://huggingface.co/llava-hf)                      | 7B/13B                           | llava_next       |
 | [LLaVA-NeXT-Video](https://huggingface.co/llava-hf)                | 7B/13B                           | llava_next_video |
-| [MiniCPM](https://huggingface.co/openbmb)                          | 1B/2B/4B                         | cpm/cpm3         |
-| [Mistral/Mixtral](https://huggingface.co/mistralai)                | 7B/8x7B/8x22B                    | mistral          |
-| [OLMo](https://huggingface.co/allenai)                             | 1B/7B                            | -                |
-| [PaliGemma](https://huggingface.co/google)                         | 3B                               | paligemma        |
-| [Phi-1.5/Phi-2](https://huggingface.co/microsoft)                  | 1.3B/2.7B                        | -                |
-| [Phi-3](https://huggingface.co/microsoft)                          | 4B/7B/14B                        | phi              |
-| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen)  | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen             |
-| [Qwen2-VL](https://huggingface.co/Qwen)                            | 2B/7B                            | qwen2_vl         |
-| [StarCoder 2](https://huggingface.co/bigcode)                      | 3B/7B/15B                        | -                |
-| [XVERSE](https://huggingface.co/xverse)                            | 7B/13B/65B                       | xverse           |
-| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai)                   | 1.5B/6B/9B/34B                   | yi               |
-| [Yi-VL](https://huggingface.co/01-ai)                              | 6B/34B                           | yi_vl            |
-| [Yuan 2](https://huggingface.co/IEITYuan)                          | 2B/51B/102B                      | yuan             |
+| [MiniCPM](https://huggingface.co/openbmb)                         | 1B/2B/4B                         | cpm/cpm3  |
+| [Mistral/Mixtral](https://huggingface.co/mistralai)               | 7B/8x7B/8x22B                    | mistral   |
+| [OLMo](https://huggingface.co/allenai)                            | 1B/7B                            | -         |
+| [PaliGemma](https://huggingface.co/google)                        | 3B                               | paligemma |
+| [Phi-1.5/Phi-2](https://huggingface.co/microsoft)                 | 1.3B/2.7B                        | -         |
+| [Phi-3](https://huggingface.co/microsoft)                         | 4B/7B/14B                        | phi       |
+| [Qwen/Qwen1.5/Qwen2 (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/4B/7B/14B/32B/72B/110B | qwen      |
+| [Qwen2-VL](https://huggingface.co/Qwen)                           | 2B/7B                            | qwen2_vl  |
+| [StarCoder 2](https://huggingface.co/bigcode)                     | 3B/7B/15B                        | -         |
+| [XVERSE](https://huggingface.co/xverse)                           | 7B/13B/65B                       | xverse    |
+| [Yi/Yi-1.5 (Code)](https://huggingface.co/01-ai)                  | 1.5B/6B/9B/34B                   | yi        |
+| [Yi-VL](https://huggingface.co/01-ai)                             | 6B/34B                           | yi_vl     |
+| [Yuan 2](https://huggingface.co/IEITYuan)                         | 2B/51B/102B                      | yuan      |
 
 > [!NOTE]
 > For the "base" models, the `template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "instruct/chat" models.
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 22c49468..919541c6 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -296,11 +296,11 @@ class LlavaPlugin(BasePlugin):
 class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
-            self,
-            messages: Sequence[Dict[str, str]],
-            images: Sequence["ImageInput"],
-            videos: Sequence["VideoInput"],
-            processor: Optional["ProcessorMixin"],
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
@@ -318,13 +318,13 @@ class LlavaNextPlugin(BasePlugin):
 
     @override
     def get_mm_inputs(
-            self,
-            images: Sequence["ImageInput"],
-            videos: Sequence["VideoInput"],
-            imglens: Sequence[int],
-            vidlens: Sequence[int],
-            seqlens: Sequence[int],
-            processor: Optional["ProcessorMixin"],
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        seqlens: Sequence[int],
+        processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         return _get_mm_inputs(images, videos, processor)
@@ -379,6 +379,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         res.update(video_res)
         return res
 
+
 class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 96fb5760..502af2a2 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -120,9 +120,12 @@ def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
     """
     init_kwargs = _get_init_kwargs(model_args)
     if "LLaVA-NeXT-Video" in model_args.model_name_or_path:
-        from transformers import PretrainedConfig, LlavaNextVideoConfig, CLIPVisionConfig, LlamaConfig
+        from transformers import CLIPVisionConfig, LlamaConfig, LlavaNextVideoConfig, PretrainedConfig
+
         official_config = PretrainedConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
-        config = LlavaNextVideoConfig(CLIPVisionConfig(**official_config.vision_config), LlamaConfig(**official_config.text_config))
+        config = LlavaNextVideoConfig(
+            CLIPVisionConfig(**official_config.vision_config), LlamaConfig(**official_config.text_config)
+        )
         setattr(config, "visual_inputs", True)
         return config
     return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)
@@ -163,6 +166,7 @@ def load_model(
         load_class = AutoModelForCausalLM
         if "llava_next_video" == getattr(config, "model_type"):
             from transformers import LlavaNextVideoForConditionalGeneration
+
             load_class = LlavaNextVideoForConditionalGeneration
 
     if model_args.train_from_scratch:
diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py
index a850d077..32662110 100644
--- a/src/llamafactory/model/model_utils/visual.py
+++ b/src/llamafactory/model/model_utils/visual.py
@@ -108,7 +108,13 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
     Patches VLMs before loading them.
     """
     model_type = getattr(config, "model_type", None)
-    if model_type in ["llava", "llava_next", "video_llava", "idefics2", "llava_next_video"]:  # required for ds zero3 and valuehead models
+    if model_type in [
+        "llava",
+        "llava_next",
+        "video_llava",
+        "idefics2",
+        "llava_next_video",
+    ]:  # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
 
     if getattr(config, "is_yi_vl_derived_model", None):
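
Note for reviewers: the sketch below replays the patched `load_config`/`load_model` path from `src/llamafactory/model/loader.py` as a standalone script, to make the new special-casing visible outside the diff. It is a minimal sketch under assumptions, not part of the patch: the checkpoint id `llava-hf/LLaVA-NeXT-Video-7B-hf` is only an illustrative example (the patch matches any path containing "LLaVA-NeXT-Video"), the `init_kwargs` plumbing derived from `ModelArguments` is omitted, and the `LlavaNextVideo*` classes require a transformers release that ships them (4.42+).

```python
# Standalone sketch of the patched config path (assumptions noted above).
from transformers import (
    CLIPVisionConfig,
    LlamaConfig,
    LlavaNextVideoConfig,
    LlavaNextVideoForConditionalGeneration,
    PretrainedConfig,
)

model_id = "llava-hf/LLaVA-NeXT-Video-7B-hf"  # illustrative checkpoint id, not from the diff

# Read the raw config, then rebuild it as an explicit LlavaNextVideoConfig
# from its vision/text sub-configs, as the patched branch does.
official_config = PretrainedConfig.from_pretrained(model_id)
config = LlavaNextVideoConfig(
    CLIPVisionConfig(**official_config.vision_config),  # vision tower config
    LlamaConfig(**official_config.text_config),  # language backbone config
)
setattr(config, "visual_inputs", True)  # flag consumed elsewhere in LLaMA-Factory

# load_model() then selects the dedicated class once model_type resolves to
# "llava_next_video", instead of falling back to AutoModelForCausalLM.
model = LlavaNextVideoForConditionalGeneration.from_pretrained(model_id, config=config)
```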