diff --git a/README.md b/README.md index dd22db35..973aac52 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors) [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-210-green)](https://scholar.google.com/scholar?cites=12620864006390196564) +[![Citation](https://img.shields.io/badge/citation-238-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -73,7 +73,7 @@ Choose your path: ## Features -- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, Yi, Gemma, Baichuan, ChatGLM, Phi, etc. +- **Various models**: LLaMA, LLaVA, Mistral, Mixtral-MoE, Qwen, Qwen2-VL, DeepSeek, Yi, Gemma, ChatGLM, Phi, etc. - **Integrated methods**: (Continuous) pre-training, (multimodal) supervised fine-tuning, reward modeling, PPO, DPO, KTO, ORPO, etc. - **Scalable resources**: 16-bit full-tuning, freeze-tuning, LoRA and 2/3/4/5/6/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ. - **Advanced algorithms**: [GaLore](https://github.com/jiaweizzhao/GaLore), [BAdam](https://github.com/Ledzy/BAdam), [APOLLO](https://github.com/zhuhanqing/APOLLO), [Adam-mini](https://github.com/zyushun/Adam-mini), DoRA, LongLoRA, LLaMA Pro, Mixture-of-Depths, LoRA+, LoftQ, PiSSA and Agent tuning. @@ -105,16 +105,18 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ ## Changelog +[25/01/31] We supported fine-tuning the **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** and **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** model. + [25/01/15] We supported **[APOLLO](https://arxiv.org/abs/2412.05270)** optimizer. See [examples](examples/README.md) for usage. [25/01/14] We supported fine-tuning the **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** and **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** models. Thank [@BUAADreamer](https://github.com/BUAADreamer)'s PR. +
+<details><summary>Full Changelog</summary>
+
[25/01/14] We supported fine-tuning the **[InternLM3](https://huggingface.co/collections/internlm/)** models. Thank [@hhaAndroid](https://github.com/hhaAndroid)'s PR.

[25/01/10] We supported fine-tuning the **[Phi-4](https://huggingface.co/microsoft/phi-4)** model.

-
Full Changelog - [24/12/21] We supported using **[SwanLab](https://github.com/SwanHubX/SwanLab)** for experiment tracking and visualization. See [this section](#use-swanlab-logger) for details. [24/11/27] We supported fine-tuning the **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** model and the **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** dataset. @@ -243,7 +245,7 @@ Compared to ChatGLM's [P-Tuning](https://github.com/THUDM/ChatGLM2-6B/tree/main/ | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 | | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | -| [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | +| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl | | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 | diff --git a/README_zh.md b/README_zh.md index 2ccf75e5..079a272a 100644 --- a/README_zh.md +++ b/README_zh.md @@ -5,7 +5,7 @@ [![GitHub contributors](https://img.shields.io/github/contributors/hiyouga/LLaMA-Factory?color=orange)](https://github.com/hiyouga/LLaMA-Factory/graphs/contributors) [![GitHub workflow](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml/badge.svg)](https://github.com/hiyouga/LLaMA-Factory/actions/workflows/tests.yml) [![PyPI](https://img.shields.io/pypi/v/llamafactory)](https://pypi.org/project/llamafactory/) -[![Citation](https://img.shields.io/badge/citation-210-green)](https://scholar.google.com/scholar?cites=12620864006390196564) +[![Citation](https://img.shields.io/badge/citation-238-green)](https://scholar.google.com/scholar?cites=12620864006390196564) [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls) [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai) @@ -75,7 +75,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 项目特色 -- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、Yi、Gemma、Baichuan、ChatGLM、Phi 等等。 +- **多种模型**:LLaMA、LLaVA、Mistral、Mixtral-MoE、Qwen、Qwen2-VL、DeepSeek、Yi、Gemma、ChatGLM、Phi 等等。 - **集成方法**:(增量)预训练、(多模态)指令监督微调、奖励模型训练、PPO 训练、DPO 训练、KTO 训练、ORPO 训练等等。 - **多种精度**:16 比特全参数微调、冻结微调、LoRA 微调和基于 AQLM/AWQ/GPTQ/LLM.int8/HQQ/EETQ 的 2/3/4/5/6/8 比特 QLoRA 微调。 - **先进算法**:[GaLore](https://github.com/jiaweizzhao/GaLore)、[BAdam](https://github.com/Ledzy/BAdam)、[APOLLO](https://github.com/zhuhanqing/APOLLO)、[Adam-mini](https://github.com/zyushun/Adam-mini)、DoRA、LongLoRA、LLaMA Pro、Mixture-of-Depths、LoRA+、LoftQ、PiSSA 和 Agent 微调。 @@ -107,16 +107,18 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 ## 更新日志 +[25/01/31] 我们支持了 **[DeepSeek-R1](https://huggingface.co/deepseek-ai/DeepSeek-R1)** 和 **[Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct)** 模型的微调。 + [25/01/15] 我们支持了 **[APOLLO](https://arxiv.org/abs/2412.05270)** 优化器。详细用法请参照 [examples](examples/README_zh.md)。 [25/01/14] 我们支持了 **[MiniCPM-o-2.6](https://huggingface.co/openbmb/MiniCPM-o-2_6)** 和 **[MiniCPM-V-2.6](https://huggingface.co/openbmb/MiniCPM-V-2_6)** 模型的微调。 感谢 [@BUAADreamer](https://github.com/BUAADreamer) 的 PR. +
+<details><summary>展开日志</summary>
+
[25/01/14] 我们支持了 **[InternLM3](https://huggingface.co/collections/internlm/)** 模型的微调。感谢 [@hhaAndroid](https://github.com/hhaAndroid) 的 PR。

[25/01/10] 我们支持了 **[Phi-4](https://huggingface.co/microsoft/phi-4)** 模型的微调。

-
展开日志 - [24/12/21] 我们支持了使用 **[SwanLab](https://github.com/SwanHubX/SwanLab)** 跟踪与可视化实验。详细用法请参考 [此部分](#使用-swanlab-面板)。 [24/11/27] 我们支持了 **[Skywork-o1](https://huggingface.co/Skywork/Skywork-o1-Open-Llama-3.1-8B)** 模型的微调和 **[OpenO1](https://huggingface.co/datasets/O1-OPEN/OpenO1-SFT)** 数据集。 @@ -245,7 +247,7 @@ https://github.com/user-attachments/assets/e6ce34b0-52d5-4f3e-a830-592106c4c272 | [Phi-4](https://huggingface.co/microsoft) | 14B | phi4 | | [Pixtral](https://huggingface.co/mistralai) | 12B | pixtral | | [Qwen/QwQ (1-2.5) (Code/Math/MoE)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | -| [Qwen2-VL/QVQ](https://huggingface.co/Qwen) | 2B/7B/72B | qwen2_vl | +| [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/72B | qwen2_vl | | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | | [StarCoder 2](https://huggingface.co/bigcode) | 3B/7B/15B | - | | [TeleChat2](https://huggingface.co/Tele-AI) | 3B/7B/35B/115B | telechat2 | diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index c2e3e114..57bbb405 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -176,7 +176,10 @@ class HuggingfaceEngine(BaseEngine): if torch.is_floating_point(value): # cast data dtype for paligemma value = value.to(model.dtype) - gen_kwargs[key] = value.to(model.device) + if key == "second_per_grid_ts": # qwen2.5vl special case + gen_kwargs[key] = value.tolist() + else: + gen_kwargs[key] = value.to(model.device) if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]: gen_kwargs["input_ids"] = inputs diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index ba6a3da7..32037ac6 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -135,12 +135,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): features: Dict[str, "torch.Tensor"] = super().__call__(features) if self.model is not None and hasattr(self.model, "get_rope_index"): # for qwen2vl mrope - features["position_ids"], features["rope_deltas"] = self.model.get_rope_index( - input_ids=features["input_ids"], - image_grid_thw=mm_inputs.get("image_grid_thw", None), - video_grid_thw=mm_inputs.get("video_grid_thw", None), - attention_mask=features["attention_mask"], - ) + rope_index_kwargs = { + "input_ids": features["input_ids"], + "image_grid_thw": mm_inputs.get("image_grid_thw"), + "video_grid_thw": mm_inputs.get("video_grid_thw"), + "attention_mask": features["attention_mask"], + } + if "second_per_grid_ts" in mm_inputs: + rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts") + + features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs) if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled cross_attention_mask = mm_inputs.pop("cross_attention_mask") diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 75bdb9d4..00945923 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -178,16 +178,16 @@ class BasePlugin: if len(images) != 0: images = self._regularize_images( images, - image_resolution=getattr(processor, "image_resolution", 512 * 512), + image_resolution=getattr(processor, "image_resolution", 768 * 768), ) input_dict["images"] = images if len(videos) != 0: videos = self._regularize_videos( videos, - image_resolution=getattr(processor, "video_resolution", 128 * 
128), + image_resolution=getattr(processor, "video_resolution", 256 * 256), video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 64), + video_maxlen=getattr(processor, "video_maxlen", 128), ) input_dict["videos"] = videos @@ -501,7 +501,7 @@ class MiniCPMVPlugin(BasePlugin): if len(images) != 0: images = self._regularize_images( images, - image_resolution=getattr(processor, "image_resolution", 512 * 512), + image_resolution=getattr(processor, "image_resolution", 768 * 768), ) if "valid_image_nums_ls" in kwargs: valid_image_nums_ls = kwargs["valid_image_nums_ls"] @@ -521,9 +521,9 @@ class MiniCPMVPlugin(BasePlugin): if len(videos) != 0: videos = self._regularize_videos( videos, - image_resolution=getattr(processor, "video_resolution", 128 * 128), + image_resolution=getattr(processor, "video_resolution", 256 * 256), video_fps=getattr(processor, "video_fps", 2.0), - video_maxlen=getattr(processor, "video_maxlen", 64), + video_maxlen=getattr(processor, "video_maxlen", 128), ) video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt") mm_inputs.update(video_inputs) @@ -610,7 +610,7 @@ class MllamaPlugin(BasePlugin): """ image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") imglens: List[int] = kwargs["imglens"] - images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512)) + images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 768 * 768)) batch_images = [] for image_length in imglens: batch_images.append(images[:image_length]) @@ -875,7 +875,15 @@ class Qwen2vlPlugin(BasePlugin): processor: Optional["ProcessorMixin"], ) -> Dict[str, Union[List[int], "torch.Tensor"]]: self._validate_input(images, videos) - return self._get_mm_inputs(images, videos, processor) + mm_inputs = self._get_mm_inputs(images, videos, processor) + image_processor: "BaseImageProcessor" = getattr(processor, "image_processor") + if "second_per_grid_ts" in getattr(image_processor, "model_input_names", []) and "video_grid_thw" in mm_inputs: + video_fps = getattr(processor, "video_fps", 2.0) + mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / video_fps] * len( + mm_inputs["video_grid_thw"] + ) + + return mm_inputs class VideoLlavaPlugin(BasePlugin): diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index c0903db1..20d84634 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -1928,6 +1928,14 @@ register_model_group( DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct", DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct", }, + "Qwen2.5-7B-Instruct-1M": { + DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-1M", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-1M", + }, + "Qwen2.5-14B-Instruct-1M": { + DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-1M", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-1M", + }, "Qwen2.5-0.5B-Instruct-GPTQ-Int8": { DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8", DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8", @@ -2149,6 +2157,18 @@ register_model_group( DownloadSource.DEFAULT: "Qwen/QVQ-72B-Preview", DownloadSource.MODELSCOPE: "Qwen/QVQ-72B-Preview", }, + "Qwen2.5-VL-3B-Instruct": { + DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-3B-Instruct", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-3B-Instruct", + }, + "Qwen2.5-VL-7B-Instruct": { + 
DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-7B-Instruct", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-7B-Instruct", + }, + "Qwen2.5-VL-72B-Instruct": { + DownloadSource.DEFAULT: "Qwen/Qwen2.5-VL-72B-Instruct", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-VL-72B-Instruct", + }, }, template="qwen2_vl", vision=True, diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index f5812d51..7f4df68c 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -59,19 +59,19 @@ class ProcessorArguments: """ image_resolution: int = field( - default=512 * 512, - metadata={"help": "Keeps the number of pixels of image below this resolution."}, + default=768 * 768, + metadata={"help": "The maximum number of pixels of image inputs."}, ) video_resolution: int = field( - default=128 * 128, - metadata={"help": "Keeps the number of pixels of video below this resolution."}, + default=256 * 256, + metadata={"help": "The maximum number of pixels of video inputs."}, ) video_fps: float = field( default=2.0, metadata={"help": "The frames to sample per second for video inputs."}, ) video_maxlen: int = field( - default=64, + default=128, metadata={"help": "The maximum number of sampled frames for video inputs."}, ) diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 5bd188cb..10eafff6 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -286,3 +286,11 @@ _register_composite_model( vision_model_keys=["visual.patch_embed", "visual.blocks"], language_model_keys=["model", "lm_head"], ) + + +_register_composite_model( + model_type="qwen2_5_vl", + projector_key="visual.merger", + vision_model_keys=["visual.patch_embed", "visual.blocks"], + language_model_keys=["model", "lm_head"], +)
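
Note for readers of this patch: the new `second_per_grid_ts` entry ties the Python hunks together. `Qwen2vlPlugin.get_mm_inputs` derives it as `temporal_patch_size / video_fps` per video (the number of seconds each temporal grid step spans), the collator forwards it to `get_rope_index` for Qwen2.5-VL's mRoPE, and the HF engine keeps it as plain Python floats. The snippet below is a minimal, self-contained sketch of the derivation only; `FakeImageProcessor` and the grid values are illustrative stand-ins, not part of the patch.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeImageProcessor:
    """Illustrative stand-in for the Qwen2.5-VL image processor."""

    temporal_patch_size: int = 2  # frames merged into one temporal grid step
    model_input_names: List[str] = field(
        default_factory=lambda: ["pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
    )


def add_second_per_grid_ts(
    mm_inputs: Dict[str, list],
    image_processor: FakeImageProcessor,
    video_fps: float = 2.0,
) -> Dict[str, list]:
    """Mirror of the Qwen2vlPlugin.get_mm_inputs change: emit one scalar per video,
    temporal_patch_size / video_fps = seconds covered by each temporal grid step."""
    if "second_per_grid_ts" in image_processor.model_input_names and "video_grid_thw" in mm_inputs:
        mm_inputs["second_per_grid_ts"] = [
            image_processor.temporal_patch_size / video_fps
        ] * len(mm_inputs["video_grid_thw"])
    return mm_inputs


if __name__ == "__main__":
    # Two videos with (t, h, w) grid sizes, sampled at 2 fps with temporal_patch_size=2,
    # so every temporal grid step spans 1.0 second.
    inputs = {"video_grid_thw": [[8, 18, 32], [4, 18, 32]]}
    print(add_second_per_grid_ts(inputs, FakeImageProcessor()))
    # -> {'video_grid_thw': [...], 'second_per_grid_ts': [1.0, 1.0]}
```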
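Two of the hunks exist only to keep `second_per_grid_ts` out of the tensor path: `collator.py` adds it to the `get_rope_index` kwargs only when the processor produced it, presumably so the older Qwen2-VL signature (which has no such keyword) keeps working, and `hf_engine.py` converts it back with `.tolist()` instead of moving it to the model device. Below is a condensed sketch of that routing, with a toy `gen_kwargs` dict standing in for the real generation inputs; it follows the special case added in this patch rather than any documented upstream API.

```python
from typing import Any, Dict

import torch


def route_mm_generation_inputs(
    raw_inputs: Dict[str, Any], device: torch.device, dtype: torch.dtype
) -> Dict[str, Any]:
    """Condensed version of the hf_engine.py handling: floating-point tensors are cast to
    the model dtype and moved to its device, while second_per_grid_ts is handed over as a
    plain Python list, matching the qwen2.5-vl special case introduced in this patch."""
    gen_kwargs: Dict[str, Any] = {}
    for key, value in raw_inputs.items():
        if not isinstance(value, torch.Tensor):
            gen_kwargs[key] = value
            continue

        if torch.is_floating_point(value):  # cast data dtype (e.g. for paligemma)
            value = value.to(dtype)

        if key == "second_per_grid_ts":  # qwen2.5-vl special case: keep as list
            gen_kwargs[key] = value.tolist()
        else:
            gen_kwargs[key] = value.to(device)

    return gen_kwargs


if __name__ == "__main__":
    demo = {
        "pixel_values_videos": torch.randn(4, 3),
        "second_per_grid_ts": torch.tensor([1.0, 1.0]),
    }
    routed = route_mm_generation_inputs(demo, torch.device("cpu"), torch.bfloat16)
    print(type(routed["second_per_grid_ts"]), routed["second_per_grid_ts"])  # <class 'list'> [1.0, 1.0]
```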