From b0c8ba73e0378f0052143eec2bcacf6ee2a88ce6 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Wed, 21 May 2025 05:16:18 +0800 Subject: [PATCH] [deps] update to transformers 4.52 (#8125) --- .github/workflows/tests.yml | 8 + README.md | 5 +- README_zh.md | 5 +- data/README.md | 4 +- data/README_zh.md | 4 +- requirements.txt | 2 +- setup.py | 2 +- src/llamafactory/data/mm_plugin.py | 18 +- src/llamafactory/data/template.py | 26 ++- src/llamafactory/extras/constants.py | 8 + src/llamafactory/extras/misc.py | 4 +- src/llamafactory/hparams/data_args.py | 2 +- src/llamafactory/hparams/model_args.py | 8 +- src/llamafactory/model/model_utils/visual.py | 5 +- src/llamafactory/model/patcher.py | 2 +- src/llamafactory/train/dpo/trainer.py | 4 +- src/llamafactory/train/kto/trainer.py | 5 +- src/llamafactory/train/pt/trainer.py | 4 +- src/llamafactory/train/rm/trainer.py | 4 +- src/llamafactory/train/sft/trainer.py | 4 +- src/llamafactory/webui/common.py | 8 + src/llamafactory/webui/components/train.py | 44 +++- src/llamafactory/webui/locales.py | 210 +++++++++++++++++-- src/llamafactory/webui/runner.py | 15 +- tests/data/test_mm_plugin.py | 15 +- tests/data/test_template.py | 32 +-- tests/model/model_utils/test_visual.py | 24 ++- tests/version.txt | 2 +- 28 files changed, 365 insertions(+), 109 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 84920f4b..414cfe1a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,6 +40,9 @@ jobs: - python: "3.9" os: "ubuntu-latest" transformers: "4.49.0" + - python: "3.9" + os: "ubuntu-latest" + transformers: "4.51.0" runs-on: ${{ matrix.os }} @@ -72,6 +75,11 @@ jobs: run: | python -m pip install "transformers==${{ matrix.transformers }}" + - name: Downgrade transformers + if: ${{ matrix.os == 'macos-13' }} + run: | + python -m pip install "transformers<4.52.0" + - name: Cache files id: hf-hub-cache uses: actions/cache@v4 diff --git a/README.md b/README.md index c02b7750..2d6eeef7 100644 --- a/README.md +++ b/README.md @@ -266,7 +266,7 @@ Choose your path: | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | -| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\* | 1B/2B/8B/14B/38B/78B | intern_vl | +| [InternVL 2.5-3](https://huggingface.co/OpenGVLab) | 1B/2B/8B/14B/38B/78B | intern_vl | | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | @@ -292,7 +292,7 @@ Choose your path: | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 | | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio | -| [Qwen2.5-Omni](https://huggingface.co/Qwen)\* | 3B/7B | qwen2_omni | +| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni | | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl | | [Seed Coder](https://huggingface.co/ByteDance-Seed) | 8B | seed_coder | | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | @@ -439,6 +439,7 @@ huggingface-cli login | ------------ | ------- | --------- | | python | 3.9 | 3.10 | | torch | 2.0.0 | 2.6.0 | +| torchvision | 0.15.0 | 0.21.0 | | transformers | 4.45.0 | 4.50.0 | | 
datasets | 2.16.0 | 3.2.0 | | accelerate | 0.34.0 | 1.2.1 | diff --git a/README_zh.md b/README_zh.md index f9029948..31c77b75 100644 --- a/README_zh.md +++ b/README_zh.md @@ -268,7 +268,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc | [Hunyuan](https://huggingface.co/tencent/) | 7B | hunyuan | | [Index](https://huggingface.co/IndexTeam) | 1.9B | index | | [InternLM 2-3](https://huggingface.co/internlm) | 7B/8B/20B | intern2 | -| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\* | 1B/2B/8B/14B/38B/78B | intern_vl | +| [InternVL 2.5-3](https://huggingface.co/OpenGVLab) | 1B/2B/8B/14B/38B/78B | intern_vl | | [Kimi-VL](https://huggingface.co/moonshotai) | 16B | kimi_vl | | [Llama](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | - | | [Llama 2](https://huggingface.co/meta-llama) | 7B/13B/70B | llama2 | @@ -294,7 +294,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen) | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen | | [Qwen3 (MoE)](https://huggingface.co/Qwen) | 0.6B/1.7B/4B/8B/14B/32B/235B | qwen3 | | [Qwen2-Audio](https://huggingface.co/Qwen) | 7B | qwen2_audio | -| [Qwen2.5-Omni](https://huggingface.co/Qwen)\* | 3B/7B | qwen2_omni | +| [Qwen2.5-Omni](https://huggingface.co/Qwen) | 3B/7B | qwen2_omni | | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen) | 2B/3B/7B/32B/72B | qwen2_vl | | [Seed Coder](https://huggingface.co/ByteDance-Seed) | 8B | seed_coder | | [Skywork o1](https://huggingface.co/Skywork) | 8B | skywork_o1 | @@ -441,6 +441,7 @@ huggingface-cli login | ------------ | ------- | --------- | | python | 3.9 | 3.10 | | torch | 2.0.0 | 2.6.0 | +| torchvision | 0.15.0 | 0.21.0 | | transformers | 4.45.0 | 4.50.0 | | datasets | 2.16.0 | 3.2.0 | | accelerate | 0.34.0 | 1.2.1 | diff --git a/data/README.md b/data/README.md index 90503351..2ed3782a 100644 --- a/data/README.md +++ b/data/README.md @@ -89,7 +89,9 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` sh ``` > [!TIP] -> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add empty CoT to the data. When `enable_thinking` is `True`, the empty CoT will be added to the model responses and loss computation will be considered; otherwise, it will be added to the user prompts and loss computation will be ignored. Please keep the `enable_thinking` parameter consistent during training and inference. +> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add empty CoT to the data. When `enable_thinking` is `True` (slow thinking), the empty CoT will be added to the model responses and loss computation will be considered; otherwise (fast thinking), it will be added to the user prompts and loss computation will be ignored. Please keep the `enable_thinking` parameter consistent during training and inference. +> +> If you want to train data containing CoT with slow thinking and data without CoT with fast thinking, you can set `enable_thinking` to `None`. However, this feature is relatively complicated and should be used with caution. 
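For illustration, here is a minimal sketch of how the three `enable_thinking` settings treat a sample whose assistant answer has no chain-of-thought. The import paths are assumed from the repository layout, and the comments mirror the `ReasoningTemplate` changes later in this patch; it is a sketch, not part of the patch itself.

```python
# Illustrative sketch only (assumed import paths; qwen3 template as used in this patch).
from transformers import AutoTokenizer

from llamafactory.data import get_template_and_fix_tokenizer
from llamafactory.hparams import DataArguments

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
messages = [
    {"role": "user", "content": "What is 1 + 1?"},
    {"role": "assistant", "content": "2"},  # no <think> block in the data
]

for enable_thinking in (True, False, None):
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages)
    # True  -> empty <think></think> is prepended to response_ids (loss is computed on it)
    # False -> existing CoT is stripped; empty <think></think> is appended to prompt_ids (no loss)
    # None  -> samples with CoT keep it in the response; samples without CoT get it in the prompt
    print(enable_thinking, len(prompt_ids), len(response_ids))
```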
### Pre-training Dataset diff --git a/data/README_zh.md b/data/README_zh.md index f26725ca..e63eb73c 100644 --- a/data/README_zh.md +++ b/data/README_zh.md @@ -88,7 +88,9 @@ ``` > [!TIP] -> 如果模型本身具备推理能力,而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时,空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失。请在训练和推理时保持 `enable_thinking` 参数一致。 +> 如果模型本身具备推理能力,而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时(慢思考),空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失(快思考)。请在训练和推理时保持 `enable_thinking` 参数一致。 +> +> 如果您希望训练包含思维链的数据时使用慢思考,训练不包含思维链的数据时使用快思考,可以设置 `enable_thinking` 为 `None`。但该功能较为复杂,请谨慎使用。 ### 预训练数据集 diff --git a/requirements.txt b/requirements.txt index 484ebae0..1faf1c07 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0 +transformers>=4.45.0,<=4.52.1,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0 datasets>=2.16.0,<=3.6.0 accelerate>=0.34.0,<=1.7.0 peft>=0.14.0,<=0.15.2 diff --git a/setup.py b/setup.py index 4c29ca87..6066a957 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_console_scripts() -> list[str]: extra_require = { - "torch": ["torch>=1.13.1"], + "torch": ["torch>=2.0.0", "torchvision>=0.15.0"], "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"], "metrics": ["nltk", "jieba", "rouge-chinese"], "deepspeed": ["deepspeed>=0.10.0,<=0.16.5"], diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index a237b907..f8598212 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -57,19 +57,11 @@ if is_transformers_version_greater_than("4.45.0"): ) -if is_transformers_version_greater_than("4.49.0"): - try: - from transformers.image_utils import make_batched_videos, make_flat_list_of_images - except ImportError: - try: - # If that fails, try importing from the new location - from transformers.image_utils import make_flat_list_of_images - from transformers.video_utils import make_batched_videos - except ImportError: - raise ImportError( - "Could not import make_batched_videos and make_flat_list_of_images. " - "In Transformers 4.52.0, make_batched_videos will be moved to transformers.video_utils." 
- ) +if is_transformers_version_greater_than("4.52.0"): + from transformers.image_utils import make_flat_list_of_images + from transformers.video_utils import make_batched_videos +elif is_transformers_version_greater_than("4.49.0"): + from transformers.image_utils import make_batched_videos, make_flat_list_of_images if TYPE_CHECKING: diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index aa422cab..74faea46 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -52,7 +52,7 @@ class Template: efficient_eos: bool replace_eos: bool replace_jinja_template: bool - enable_thinking: bool + enable_thinking: Optional[bool] mm_plugin: "BasePlugin" def encode_oneturn( @@ -411,14 +411,17 @@ class ReasoningTemplate(Template): for i in range(1, len(messages) - 2, 2): messages[i]["content"] = self.remove_thought(messages[i]["content"]) + if self.enable_thinking is False: # remove all cot + messages[-1]["content"] = self.remove_thought(messages[-1]["content"]) + prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools) if ( self.thought_words[0] not in messages[-1]["content"] and self.thought_words[1] not in messages[-1]["content"] - ): - if not self.enable_thinking: - prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer) - else: + ): # add empty cot + if not self.enable_thinking: # do not compute loss + prompt_ids += self.get_thought_word_ids(tokenizer) + else: # do compute loss response_ids = self.get_thought_word_ids(tokenizer) + response_ids return prompt_ids, response_ids @@ -431,15 +434,20 @@ class ReasoningTemplate(Template): system: Optional[str] = None, tools: Optional[str] = None, ) -> list[tuple[list[int], list[int]]]: + messages = deepcopy(messages) + if self.enable_thinking is False: # remove all cot + for i in range(1, len(messages), 2): + messages[i]["content"] = self.remove_thought(messages[i]["content"]) + encoded_messages = self._encode(tokenizer, messages, system, tools) for i in range(0, len(messages), 2): if ( self.thought_words[0] not in messages[i + 1]["content"] and self.thought_words[1] not in messages[i + 1]["content"] - ): - if not self.enable_thinking: + ): # add empty cot + if not self.enable_thinking: # do not compute loss encoded_messages[i] += self.get_thought_word_ids(tokenizer) - else: + else: # do compute loss encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1] return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] @@ -463,7 +471,7 @@ def register_template( efficient_eos: bool = False, replace_eos: bool = False, replace_jinja_template: bool = False, - enable_thinking: bool = True, + enable_thinking: Optional[bool] = True, mm_plugin: "BasePlugin" = get_mm_plugin(name="base"), template_class: type["Template"] = Template, ) -> None: diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py index a461aeee..5dc7c3f3 100644 --- a/src/llamafactory/extras/constants.py +++ b/src/llamafactory/extras/constants.py @@ -2566,6 +2566,14 @@ register_model_group( DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B", DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B", }, + "Qwen2.5-Omni-7B-GPTQ-Int4": { + DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4", + }, + "Qwen2.5-Omni-7B-AWQ": { + DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ", + DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ", + }, }, 
template="qwen2_omni", multimodal=True, diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index dcc22c1b..1ad4c1cf 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -94,7 +94,9 @@ def check_version(requirement: str, mandatory: bool = False) -> None: def check_dependencies() -> None: r"""Check the version of the required packages.""" - check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0") + check_version( + "transformers>=4.45.0,<=4.52.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0" + ) check_version("datasets>=2.16.0,<=3.6.0") check_version("accelerate>=0.34.0,<=1.7.0") check_version("peft>=0.14.0,<=0.15.2") diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 588b8c5c..c84fb2f7 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -119,7 +119,7 @@ class DataArguments: default=None, metadata={"help": "Override the default system message in the template."}, ) - enable_thinking: bool = field( + enable_thinking: Optional[bool] = field( default=True, metadata={"help": "Whether or not to enable thinking mode for reasoning models."}, ) diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index eec9ceca..d2f7cc52 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -235,10 +235,6 @@ class ProcessorArguments: default=False, metadata={"help": "Whether to crop the image to patches for internvl."}, ) - use_audio_in_video: bool = field( - default=False, - metadata={"help": "Whether or not to use audio in video inputs."}, - ) video_max_pixels: int = field( default=256 * 256, metadata={"help": "The maximum number of pixels of video inputs."}, @@ -255,6 +251,10 @@ class ProcessorArguments: default=128, metadata={"help": "The maximum number of sampled frames for video inputs."}, ) + use_audio_in_video: bool = field( + default=False, + metadata={"help": "Whether or not to use audio in video inputs."}, + ) audio_sampling_rate: int = field( default=16000, metadata={"help": "The sampling rate of audio inputs."}, diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 901010c7..247de48b 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -24,6 +24,7 @@ import transformers.models from transformers.activations import ACT2FN from ...extras import logging +from ...extras.packages import is_transformers_version_greater_than if TYPE_CHECKING: @@ -281,7 +282,7 @@ _register_composite_model( model_type="qwen2_vl", projector_key="visual.merger", vision_model_keys=["visual.patch_embed", "visual.blocks"], - language_model_keys=["model", "lm_head"], + language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"], lora_conflict_keys=["patch_embed"], ) @@ -290,6 +291,6 @@ _register_composite_model( model_type="qwen2_5_vl", projector_key="visual.merger", vision_model_keys=["visual.patch_embed", "visual.blocks"], - language_model_keys=["model", "lm_head"], + language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"], lora_conflict_keys=["patch_embed"], ) diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index cedcf9da..20228812 100644 --- 
a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -85,8 +85,8 @@ def patch_processor( setattr(processor, "video_min_pixels", model_args.video_min_pixels) setattr(processor, "video_fps", model_args.video_fps) setattr(processor, "video_maxlen", model_args.video_maxlen) - setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate) setattr(processor, "use_audio_in_video", model_args.use_audio_in_video) + setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate) def patch_config( diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py index 80f67c6c..2539127c 100644 --- a/src/llamafactory/train/dpo/trainer.py +++ b/src/llamafactory/train/dpo/trainer.py @@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer): return super().create_scheduler(num_training_steps, optimizer) @override - def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]: if self.finetuning_args.disable_shuffling: return torch.utils.data.SequentialSampler(self.train_dataset) - return super()._get_train_sampler() + return super()._get_train_sampler(*args, **kwargs) @override def get_batch_samples(self, *args, **kwargs): diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py index 0323041f..f67d0ece 100644 --- a/src/llamafactory/train/kto/trainer.py +++ b/src/llamafactory/train/kto/trainer.py @@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge if TYPE_CHECKING: - import torch.utils.data from transformers import PreTrainedModel, ProcessorMixin from ...hparams import FinetuningArguments @@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer): return super().create_scheduler(num_training_steps, optimizer) @override - def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]: r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler.""" if self.finetuning_args.disable_shuffling: return torch.utils.data.SequentialSampler(self.train_dataset) - return Trainer._get_train_sampler(self) + return Trainer._get_train_sampler(self, *args, **kwargs) @override def get_batch_samples(self, *args, **kwargs): diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 8495bbb2..096cbf68 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -70,11 +70,11 @@ class CustomTrainer(Trainer): return super().create_scheduler(num_training_steps, optimizer) @override - def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]: if self.finetuning_args.disable_shuffling: return torch.utils.data.SequentialSampler(self.train_dataset) - return super()._get_train_sampler() + return super()._get_train_sampler(*args, **kwargs) @override def compute_loss(self, model, inputs, *args, **kwargs): diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index 8c14b0ab..fe2bd557 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer): return super().create_scheduler(num_training_steps, optimizer) @override - def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + 
def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]: if self.finetuning_args.disable_shuffling: return torch.utils.data.SequentialSampler(self.train_dataset) - return super()._get_train_sampler() + return super()._get_train_sampler(*args, **kwargs) @override def compute_loss( diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index fece1bd9..d0b8d05b 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): return super().create_scheduler(num_training_steps, optimizer) @override - def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]: if self.finetuning_args.disable_shuffling: return torch.utils.data.SequentialSampler(self.train_dataset) - return super()._get_train_sampler() + return super()._get_train_sampler(*args, **kwargs) @override def compute_loss(self, model, inputs, *args, **kwargs): diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 2387174a..887f2517 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -205,6 +205,14 @@ def load_eval_results(path: os.PathLike) -> str: return f"```json\n{result}\n```\n" +def calculate_pixels(pixels: str) -> int: + r"""Calculate the number of pixels from the expression.""" + if "*" in pixels: + return int(pixels.split("*")[0]) * int(pixels.split("*")[1]) + else: + return int(pixels) + + def create_ds_config() -> None: r"""Create deepspeed config in the current directory.""" os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py index 7ca99647..8b7aa6e9 100644 --- a/src/llamafactory/webui/components/train.py +++ b/src/llamafactory/webui/components/train.py @@ -106,11 +106,11 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]: use_llama_pro = gr.Checkbox() with gr.Column(): + enable_thinking = gr.Checkbox(value=True) report_to = gr.Dropdown( - choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"], - value=["none"], + choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"], + value="none", allow_custom_value=True, - multiselect=True, ) input_elems.update( @@ -126,6 +126,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]: mask_history, resize_vocab, use_llama_pro, + enable_thinking, report_to, } ) @@ -143,6 +144,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]: mask_history=mask_history, resize_vocab=resize_vocab, use_llama_pro=use_llama_pro, + enable_thinking=enable_thinking, report_to=report_to, ) ) @@ -231,6 +233,42 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]: ) ) + with gr.Accordion(open=False) as mm_tab: + with gr.Row(): + freeze_vision_tower = gr.Checkbox(value=True) + freeze_multi_modal_projector = gr.Checkbox(value=True) + freeze_language_model = gr.Checkbox(value=False) + + with gr.Row(): + image_max_pixels = gr.Textbox(value="768*768") + image_min_pixels = gr.Textbox(value="32*32") + video_max_pixels = gr.Textbox(value="256*256") + video_min_pixels = gr.Textbox(value="16*16") + + input_elems.update( + { + freeze_vision_tower, + freeze_multi_modal_projector, + freeze_language_model, + image_max_pixels, + image_min_pixels, + video_max_pixels, + video_min_pixels, + } + ) + elem_dict.update( + dict( + 
mm_tab=mm_tab, + freeze_vision_tower=freeze_vision_tower, + freeze_multi_modal_projector=freeze_multi_modal_projector, + freeze_language_model=freeze_language_model, + image_max_pixels=image_max_pixels, + image_min_pixels=image_min_pixels, + video_max_pixels=video_max_pixels, + video_min_pixels=video_min_pixels, + ) + ) + with gr.Accordion(open=False) as galore_tab: with gr.Row(): use_galore = gr.Checkbox() diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py index ed05bae7..17ffb0a4 100644 --- a/src/llamafactory/webui/locales.py +++ b/src/llamafactory/webui/locales.py @@ -871,6 +871,28 @@ LOCALES = { "info": "拡張ブロックのパラメータのみをトレーニングします。", }, }, + "enable_thinking": { + "en": { + "label": "Enable thinking", + "info": "Whether or not to enable thinking mode for reasoning models.", + }, + "ru": { + "label": "Включить мысли", + "info": "Включить режим мысли для моделей решающего характера.", + }, + "zh": { + "label": "启用思考模式", + "info": "是否启用推理模型的思考模式。", + }, + "ko": { + "label": "생각 모드 활성화", + "info": "추론 모델의 생각 모드를 활성화할지 여부.", + }, + "ja": { + "label": "思考モードを有効化", + "info": "推論モデルの思考モードを有効にするかどうか。", + }, + }, "report_to": { "en": { "label": "Enable external logger", @@ -1374,6 +1396,177 @@ LOCALES = { "info": "PPO トレーニングにおいて報酬スコアをホワイトニング処理します。", }, }, + "mm_tab": { + "en": { + "label": "Multimodal configurations", + }, + "ru": { + "label": "Конфигурации мультимедиа", + }, + "zh": { + "label": "多模态参数设置", + }, + "ko": { + "label": "멀티모달 구성", + }, + "ja": { + "label": "多モーダル設定", + }, + }, + "freeze_vision_tower": { + "en": { + "label": "Freeze vision tower", + "info": "Freeze the vision tower in the model.", + }, + "ru": { + "label": "Заморозить башню визиона", + "info": "Заморозить башню визиона в модели.", + }, + "zh": { + "label": "冻结视觉编码器", + "info": "冻结模型中的视觉编码器。", + }, + "ko": { + "label": "비전 타워 고정", + "info": "모델의 비전 타워를 고정합니다.", + }, + "ja": { + "label": "ビジョンタワーの固定", + "info": "モデルのビジョンタワーを固定します。", + }, + }, + "freeze_multi_modal_projector": { + "en": { + "label": "Freeze multi-modal projector", + "info": "Freeze the multi-modal projector in the model.", + }, + "ru": { + "label": "Заморозить мультимодальный проектор", + "info": "Заморозить мультимодальный проектор в модели.", + }, + "zh": { + "label": "冻结多模态投影器", + "info": "冻结模型中的多模态投影器。", + }, + "ko": { + "label": "멀티모달 프로젝터 고정", + "info": "모델의 멀티모달 프로젝터를 고정합니다.", + }, + "ja": { + "label": "多モーダルプロジェクターの固定", + "info": "モデルの多モーダルプロジェクターを固定します。", + }, + }, + "freeze_language_model": { + "en": { + "label": "Freeze language model", + "info": "Freeze the language model in the model.", + }, + "ru": { + "label": "Заморозить язык модели", + "info": "Заморозить язык модели в модели.", + }, + "zh": { + "label": "冻结语言模型", + "info": "冻结模型中的语言模型。", + }, + "ko": { + "label": "언어 모델 고정", + "info": "모델의 언어 모델을 고정합니다.", + }, + "ja": { + "label": "言語モデルの固定", + "info": "モデルの言語モデルを固定します。", + }, + }, + "image_max_pixels": { + "en": { + "label": "Image max pixels", + "info": "The maximum number of pixels of image inputs.", + }, + "ru": { + "label": "Максимальное количество пикселей изображения", + "info": "Максимальное количество пикселей изображения.", + }, + "zh": { + "label": "图像最大像素", + "info": "输入图像的最大像素数。", + }, + "ko": { + "label": "이미지 최대 픽셀", + "info": "이미지 입력의 최대 픽셀 수입니다.", + }, + "ja": { + "label": "画像最大ピクセル", + "info": "画像入力の最大ピクセル数です。", + }, + }, + "image_min_pixels": { + "en": { + "label": "Image min pixels", + "info": "The minimum number of pixels of image inputs.", + }, + "ru": { + "label": "Минимальное 
количество пикселей изображения", + "info": "Минимальное количество пикселей изображения.", + }, + "zh": { + "label": "图像最小像素", + "info": "输入图像的最小像素数。", + }, + "ko": { + "label": "이미지 최소 픽셀", + "info": "이미지 입력의 최소 픽셀 수입니다.", + }, + "ja": { + "label": "画像最小ピクセル", + "info": "画像入力の最小ピクセル数です。", + }, + }, + "video_max_pixels": { + "en": { + "label": "Video max pixels", + "info": "The maximum number of pixels of video inputs.", + }, + "ru": { + "label": "Максимальное количество пикселей видео", + "info": "Максимальное количество пикселей видео.", + }, + "zh": { + "label": "视频最大像素", + "info": "输入视频的最大像素数。", + }, + "ko": { + "label": "비디오 최대 픽셀", + "info": "비디오 입력의 최대 픽셀 수입니다.", + }, + "ja": { + "label": "ビデオ最大ピクセル", + "info": "ビデオ入力の最大ピクセル数です。", + }, + }, + "video_min_pixels": { + "en": { + "label": "Video min pixels", + "info": "The minimum number of pixels of video inputs.", + }, + "ru": { + "label": "Минимальное количество пикселей видео", + "info": "Минимальное количество пикселей видео.", + }, + "zh": { + "label": "视频最小像素", + "info": "输入视频的最小像素数。", + }, + "ko": { + "label": "비디오 최소 픽셀", + "info": "비디오 입력의 최소 픽셀 수입니다.", + }, + "ja": { + "label": "ビデオ最小ピクセル", + "info": "ビデオ入力の最小ピクセル数です。", + }, + }, "galore_tab": { "en": { "label": "GaLore configurations", @@ -2468,23 +2661,6 @@ LOCALES = { "label": "HTML タグをエスケープ", }, }, - "enable_thinking": { - "en": { - "label": "Enable thinking", - }, - "ru": { - "label": "Включить мышление", - }, - "zh": { - "label": "启用思考", - }, - "ko": { - "label": "사고를 활성화하다", - }, - "ja": { - "label": "思考を可能にする", - }, - }, "clear_btn": { "en": { "value": "Clear history", diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 4fbebde1..3715974a 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -29,6 +29,7 @@ from .common import ( DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, abort_process, + calculate_pixels, gen_cmd, get_save_dir, load_args, @@ -162,7 +163,15 @@ class Runner: mask_history=get("train.mask_history"), resize_vocab=get("train.resize_vocab"), use_llama_pro=get("train.use_llama_pro"), + enable_thinking=get("train.enable_thinking"), report_to=get("train.report_to"), + freeze_vision_tower=get("train.freeze_vision_tower"), + freeze_multi_modal_projector=get("train.freeze_multi_modal_projector"), + freeze_language_model=get("train.freeze_language_model"), + image_max_pixels=calculate_pixels(get("train.image_max_pixels")), + image_min_pixels=calculate_pixels(get("train.image_min_pixels")), + video_max_pixels=calculate_pixels(get("train.video_max_pixels")), + video_min_pixels=calculate_pixels(get("train.video_min_pixels")), use_galore=get("train.use_galore"), use_apollo=get("train.use_apollo"), use_badam=get("train.use_badam"), @@ -256,12 +265,6 @@ class Runner: args["badam_switch_interval"] = get("train.badam_switch_interval") args["badam_update_ratio"] = get("train.badam_update_ratio") - # report_to - if "none" in args["report_to"]: - args["report_to"] = "none" - elif "all" in args["report_to"]: - args["report_to"] = "all" - # swanlab config if get("train.use_swanlab"): args["swanlab_project"] = get("train.swanlab_project") diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index a19de12a..6910cbad 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -135,8 +135,7 @@ def _check_plugin( expected_mm_inputs: dict[str, Any] = {}, expected_no_mm_inputs: dict[str, Any] = {}, ) -> None: - # test omni_messages - if plugin.__class__.__name__ == 
"Qwen2OmniPlugin": + if plugin.__class__.__name__ == "Qwen2OmniPlugin": # test omni_messages assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == ( expected_input_ids, @@ -146,8 +145,7 @@ def _check_plugin( plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor), expected_mm_inputs, ) - # test mm_messages - if plugin.__class__.__name__ != "BasePlugin": + elif plugin.__class__.__name__ != "BasePlugin": # test mm_messages assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == ( expected_input_ids, @@ -201,7 +199,7 @@ def test_gemma3_plugin(): _check_plugin(**check_inputs) -@pytest.mark.xfail(reason="Unknown error.") +@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0") def test_internvl_plugin(): image_seqlen = 256 tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf") @@ -219,7 +217,7 @@ def test_internvl_plugin(): _check_plugin(**check_inputs) -@pytest.mark.xfail(reason="Unknown error.") +@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0") def test_llama4_plugin(): tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4) processor = tokenizer_module["processor"] @@ -321,10 +319,9 @@ def test_pixtral_plugin(): _check_plugin(**check_inputs) -@pytest.mark.xfail(reason="Unknown error.") +@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0") def test_qwen2_omni_plugin(): - image_seqlen = 4 - audio_seqlen = 2 + image_seqlen, audio_seqlen = 4, 2 tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B") qwen2_omni_plugin = get_mm_plugin( name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>" diff --git a/tests/data/test_template.py b/tests/data/test_template.py index eef52efe..1f04e2ee 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -127,20 +127,21 @@ def test_encode_multiturn(use_fast: bool): @pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False]) -@pytest.mark.parametrize("enable_thinking", [True, False]) +@pytest.mark.parametrize("enable_thinking", [True, False, None]) def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool): - messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES + input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast) data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking) template = get_template_and_fix_tokenizer(tokenizer, data_args) - prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages) + prompt_ids, answer_ids = template.encode_oneturn(tokenizer, input_messages) + output_messages = MESSAGES if enable_thinking is False else input_messages prompt_str = ( - f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n" + f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n" f"{MESSAGES[1]['content']}<|im_end|>\n" - 
f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n" + f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n" ) - answer_str = f"{messages[3]['content']}<|im_end|>\n" - if not cot_messages: + answer_str = f"{output_messages[3]['content']}<|im_end|>\n" + if not cot_messages or enable_thinking is False: if enable_thinking: answer_str = "\n\n\n\n" + answer_str else: @@ -151,18 +152,19 @@ def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thi @pytest.mark.parametrize("use_fast", [True, False]) @pytest.mark.parametrize("cot_messages", [True, False]) -@pytest.mark.parametrize("enable_thinking", [True, False]) +@pytest.mark.parametrize("enable_thinking", [True, False, None]) def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool): - messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES + input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast) data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking) template = get_template_and_fix_tokenizer(tokenizer, data_args) - encoded_pairs = template.encode_multiturn(tokenizer, messages) - prompt_str_1 = f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n" - answer_str_1 = f"{messages[1]['content']}<|im_end|>\n" - prompt_str_2 = f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n" - answer_str_2 = f"{messages[3]['content']}<|im_end|>\n" - if not cot_messages: + encoded_pairs = template.encode_multiturn(tokenizer, input_messages) + output_messages = MESSAGES if enable_thinking is False else input_messages + prompt_str_1 = f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n" + answer_str_1 = f"{output_messages[1]['content']}<|im_end|>\n" + prompt_str_2 = f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n" + answer_str_2 = f"{output_messages[3]['content']}<|im_end|>\n" + if not cot_messages or enable_thinking is False: if enable_thinking: answer_str_1 = "\n\n\n\n" + answer_str_1 answer_str_2 = "\n\n\n\n" + answer_str_2 diff --git a/tests/model/model_utils/test_visual.py b/tests/model/model_utils/test_visual.py index 44abe349..4bba46d6 100644 --- a/tests/model/model_utils/test_visual.py +++ b/tests/model/model_utils/test_visual.py @@ -16,6 +16,7 @@ import pytest import torch from transformers import AutoConfig, AutoModelForVision2Seq +from llamafactory.extras.packages import is_transformers_version_greater_than from llamafactory.hparams import FinetuningArguments, ModelArguments from llamafactory.model.adapter import init_adapter @@ -45,10 +46,12 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo assert param.requires_grad != freeze_language_model -@pytest.mark.parametrize("freeze_vision_tower", (False, True)) -def test_visual_lora(freeze_vision_tower: bool): +@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False))) +def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool): model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct") - finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower) + finetuning_args = FinetuningArguments( + finetuning_type="lora", freeze_vision_tower=freeze_vision_tower, 
freeze_language_model=freeze_language_model + ) config = AutoConfig.from_pretrained(model_args.model_name_or_path) with torch.device("meta"): model = AutoModelForVision2Seq.from_config(config) @@ -61,10 +64,15 @@ def test_visual_lora(freeze_vision_tower: bool): else: frozen_params.add(name) - if freeze_vision_tower: - assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" not in trainable_params + if is_transformers_version_greater_than("4.52.0"): + visual_param_name = "base_model.model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" + language_param_name = "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.default.weight" + merger_param_name = "base_model.model.model.visual.merger.lora_A.default.weight" else: - assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" in trainable_params + visual_param_name = "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" + language_param_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" + merger_param_name = "base_model.model.visual.merger.lora_A.default.weight" - assert "merger" not in trainable_params - assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" in trainable_params + assert (visual_param_name in trainable_params) != freeze_vision_tower + assert (language_param_name in trainable_params) != freeze_language_model + assert (merger_param_name in trainable_params) is False diff --git a/tests/version.txt b/tests/version.txt index bbb71ad5..dae5ebba 100644 --- a/tests/version.txt +++ b/tests/version.txt @@ -1,2 +1,2 @@ # change if test fails or cache is outdated -0.9.3.106 +0.9.3.107
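
As a closing note on the trainer changes above, here is a minimal sketch of the sampler-override pattern they all adopt. `DemoTrainer` and the `disable_shuffling` flag are hypothetical stand-ins; the assumption is that newer transformers versions pass extra arguments (such as the train dataset) into `Trainer._get_train_sampler`, which is why the overrides accept and forward `*args, **kwargs` instead of pinning the old zero-argument signature.

```python
# Sketch only, not part of the patch: forward-compatible _get_train_sampler override.
from typing import Optional

import torch
from transformers import Trainer
from typing_extensions import override


class DemoTrainer(Trainer):  # hypothetical trainer for illustration
    def __init__(self, *args, disable_shuffling: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.disable_shuffling = disable_shuffling  # hypothetical flag

    @override
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if self.disable_shuffling:  # deterministic order instead of the default random sampler
            return torch.utils.data.SequentialSampler(self.train_dataset)

        return super()._get_train_sampler(*args, **kwargs)  # works with old and new signatures
```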