Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

commit 9ae17cd173
parent 56926d76f9

    [deps] update to transformers 4.52 (#8125)
 .github/workflows/tests.yml | 8 (vendored)
@@ -40,6 +40,9 @@ jobs:
           - python: "3.9"
             os: "ubuntu-latest"
             transformers: "4.49.0"
+          - python: "3.9"
+            os: "ubuntu-latest"
+            transformers: "4.51.0"

     runs-on: ${{ matrix.os }}

@@ -72,6 +75,11 @@ jobs:
         run: |
           python -m pip install "transformers==${{ matrix.transformers }}"

+      - name: Downgrade transformers
+        if: ${{ matrix.os == 'macos-13' }}
+        run: |
+          python -m pip install "transformers<4.52.0"
+
       - name: Cache files
         id: hf-hub-cache
         uses: actions/cache@v4
@@ -266,7 +266,7 @@ Choose your path:
 | [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index               |
 | [InternLM 2-3](https://huggingface.co/internlm)                   | 7B/8B/20B                        | intern2             |
-| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*              | 1B/2B/8B/14B/38B/78B             | intern_vl           |
+| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)                | 1B/2B/8B/14B/38B/78B             | intern_vl           |
 | [Kimi-VL](https://huggingface.co/moonshotai)                      | 16B                              | kimi_vl             |
 | [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                   |
 | [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2              |
@@ -292,7 +292,7 @@ Choose your path:
 | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen)   | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen                |
 | [Qwen3 (MoE)](https://huggingface.co/Qwen)                        | 0.6B/1.7B/4B/8B/14B/32B/235B     | qwen3               |
 | [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio         |
-| [Qwen2.5-Omni](https://huggingface.co/Qwen)\*                     | 3B/7B                            | qwen2_omni          |
+| [Qwen2.5-Omni](https://huggingface.co/Qwen)                       | 3B/7B                            | qwen2_omni          |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/32B/72B                 | qwen2_vl            |
 | [Seed Coder](https://huggingface.co/ByteDance-Seed)               | 8B                               | seed_coder          |
 | [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1          |
@@ -439,6 +439,7 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.9     | 3.10      |
 | torch        | 2.0.0   | 2.6.0     |
+| torchvision  | 0.15.0  | 0.21.0    |
 | transformers | 4.45.0  | 4.50.0    |
 | datasets     | 2.16.0  | 3.2.0     |
 | accelerate   | 0.34.0  | 1.2.1     |
@@ -268,7 +268,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Hunyuan](https://huggingface.co/tencent/)                        | 7B                               | hunyuan             |
 | [Index](https://huggingface.co/IndexTeam)                         | 1.9B                             | index               |
 | [InternLM 2-3](https://huggingface.co/internlm)                   | 7B/8B/20B                        | intern2             |
-| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)\*              | 1B/2B/8B/14B/38B/78B             | intern_vl           |
+| [InternVL 2.5-3](https://huggingface.co/OpenGVLab)                | 1B/2B/8B/14B/38B/78B             | intern_vl           |
 | [Kimi-VL](https://huggingface.co/moonshotai)                      | 16B                              | kimi_vl             |
 | [Llama](https://github.com/facebookresearch/llama)                | 7B/13B/33B/65B                   | -                   |
 | [Llama 2](https://huggingface.co/meta-llama)                      | 7B/13B/70B                       | llama2              |
@@ -294,7 +294,7 @@ https://github.com/user-attachments/assets/43b700c6-a178-41db-b1f8-8190a5d3fcfc
 | [Qwen (1-2.5) (Code/Math/MoE/QwQ)](https://huggingface.co/Qwen)   | 0.5B/1.5B/3B/7B/14B/32B/72B/110B | qwen                |
 | [Qwen3 (MoE)](https://huggingface.co/Qwen)                        | 0.6B/1.7B/4B/8B/14B/32B/235B     | qwen3               |
 | [Qwen2-Audio](https://huggingface.co/Qwen)                        | 7B                               | qwen2_audio         |
-| [Qwen2.5-Omni](https://huggingface.co/Qwen)\*                     | 3B/7B                            | qwen2_omni          |
+| [Qwen2.5-Omni](https://huggingface.co/Qwen)                       | 3B/7B                            | qwen2_omni          |
 | [Qwen2-VL/Qwen2.5-VL/QVQ](https://huggingface.co/Qwen)            | 2B/3B/7B/32B/72B                 | qwen2_vl            |
 | [Seed Coder](https://huggingface.co/ByteDance-Seed)               | 8B                               | seed_coder          |
 | [Skywork o1](https://huggingface.co/Skywork)                      | 8B                               | skywork_o1          |
@@ -441,6 +441,7 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.9     | 3.10      |
 | torch        | 2.0.0   | 2.6.0     |
+| torchvision  | 0.15.0  | 0.21.0    |
 | transformers | 4.45.0  | 4.50.0    |
 | datasets     | 2.16.0  | 3.2.0     |
 | accelerate   | 0.34.0  | 1.2.1     |
@@ -89,7 +89,9 @@ Regarding the above dataset, the *dataset description* in `dataset_info.json` should be:
 ```

 > [!TIP]
-> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add empty CoT to the data. When `enable_thinking` is `True`, the empty CoT will be added to the model responses and loss computation will be considered; otherwise, it will be added to the user prompts and loss computation will be ignored. Please keep the `enable_thinking` parameter consistent during training and inference.
+> If the model has reasoning capabilities but the dataset does not contain chain-of-thought (CoT), LLaMA-Factory will automatically add an empty CoT to the data. When `enable_thinking` is `True` (slow thinking), the empty CoT is added to the model responses and included in the loss computation; otherwise (fast thinking), it is added to the user prompts and excluded from the loss computation. Please keep the `enable_thinking` parameter consistent during training and inference.
+>
+> If you want to train on data containing CoT with slow thinking and on data without CoT with fast thinking, set `enable_thinking` to `None`. However, this feature is relatively complicated and should be used with caution.

 ### Pre-training Dataset
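The tri-state behavior described in the TIP can be summarized with a minimal sketch. The names `add_empty_cot` and `EMPTY_COT` are illustrative only, not LLaMA-Factory's API; the real logic lives in `ReasoningTemplate` (see the template.py hunks below):

```python
# A minimal sketch of the tri-state `enable_thinking` behavior described in the TIP;
# `add_empty_cot` and `EMPTY_COT` are illustrative names, not LLaMA-Factory's API.
from typing import Optional, Tuple

EMPTY_COT = "<think>\n\n</think>\n\n"

def add_empty_cot(prompt: str, response: str, enable_thinking: Optional[bool]) -> Tuple[str, str]:
    if "<think>" in response:  # the sample already contains CoT
        if enable_thinking is False:  # fast thinking: strip the existing CoT
            response = response.split("</think>", 1)[-1].lstrip("\n")
        return prompt, response  # True or None: keep the CoT in the response
    if enable_thinking:  # slow thinking: empty CoT joins the response, so it enters the loss
        return prompt, EMPTY_COT + response
    # False or None on CoT-free data: empty CoT joins the prompt and stays out of the loss
    return prompt + EMPTY_COT, response
```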
@@ -88,7 +88,9 @@
 ```

 > [!TIP]
-> 如果模型本身具备推理能力,而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时,空思维链会添加到模型回答中并且计算损失,否则会添加到用户指令中并且不计算损失。请在训练和推理时保持 `enable_thinking` 参数一致。
+> 如果模型本身具备推理能力,而数据集不包含思维链,LLaMA-Factory 会自动为数据添加空思维链。当 `enable_thinking` 为 `True` 时(慢思考),空思维链会添加到模型回答中并且计算损失;否则(快思考),空思维链会添加到用户指令中并且不计算损失。请在训练和推理时保持 `enable_thinking` 参数一致。
+>
+> 如果您希望训练包含思维链的数据时使用慢思考,训练不包含思维链的数据时使用快思考,可以设置 `enable_thinking` 为 `None`。但该功能较为复杂,请谨慎使用。

 ### 预训练数据集
@@ -1,4 +1,4 @@
-transformers>=4.45.0,<=4.51.3,!=4.46.*,!=4.47.*,!=4.48.0
+transformers>=4.45.0,<=4.52.1,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0
 datasets>=2.16.0,<=3.6.0
 accelerate>=0.34.0,<=1.7.0
 peft>=0.14.0,<=0.15.2
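The new pin admits 4.52.1 while still excluding 4.52.0 and the long-excluded 4.46-4.48.0 range. A quick sanity check with the `packaging` library (not part of this commit) confirms the specifier behaves as intended:

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet(">=4.45.0,<=4.52.1,!=4.46.*,!=4.47.*,!=4.48.0,!=4.52.0")
assert "4.52.1" in spec      # the new upper bound is allowed
assert "4.52.0" not in spec  # the .0 release stays excluded
assert "4.47.1" not in spec  # previously excluded ranges still apply
```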
 setup.py | 2
@@ -42,7 +42,7 @@ def get_console_scripts() -> list[str]:


 extra_require = {
-    "torch": ["torch>=1.13.1"],
+    "torch": ["torch>=2.0.0", "torchvision>=0.15.0"],
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "deepspeed": ["deepspeed>=0.10.0,<=0.16.5"],
@@ -57,19 +57,11 @@ if is_transformers_version_greater_than("4.45.0"):
     )


-if is_transformers_version_greater_than("4.49.0"):
-    try:
-        from transformers.image_utils import make_batched_videos, make_flat_list_of_images
-    except ImportError:
-        try:
-            # If that fails, try importing from the new location
-            from transformers.image_utils import make_flat_list_of_images
-            from transformers.video_utils import make_batched_videos
-        except ImportError:
-            raise ImportError(
-                "Could not import make_batched_videos and make_flat_list_of_images. "
-                "In Transformers 4.52.0, make_batched_videos will be moved to transformers.video_utils."
-            )
+if is_transformers_version_greater_than("4.52.0"):
+    from transformers.image_utils import make_flat_list_of_images
+    from transformers.video_utils import make_batched_videos
+elif is_transformers_version_greater_than("4.49.0"):
+    from transformers.image_utils import make_batched_videos, make_flat_list_of_images


 if TYPE_CHECKING:
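This file and the model utilities below all branch on `is_transformers_version_greater_than`. For reference, a helper with that behavior could look roughly like the following sketch; the real implementation lives in `llamafactory.extras.packages` and may cache results or differ in detail:

```python
# Plausible sketch of the version gate used above, written under assumptions;
# not copied from the repository.
import importlib.metadata

from packaging.version import Version

def is_transformers_version_greater_than(content: str) -> bool:
    # despite the name, "greater than" is assumed here to mean ">="
    return Version(importlib.metadata.version("transformers")) >= Version(content)
```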
@@ -52,7 +52,7 @@ class Template:
     efficient_eos: bool
     replace_eos: bool
     replace_jinja_template: bool
-    enable_thinking: bool
+    enable_thinking: Optional[bool]
     mm_plugin: "BasePlugin"

     def encode_oneturn(
@@ -411,14 +411,17 @@ class ReasoningTemplate(Template):
         for i in range(1, len(messages) - 2, 2):
             messages[i]["content"] = self.remove_thought(messages[i]["content"])

+        if self.enable_thinking is False:  # remove all cot
+            messages[-1]["content"] = self.remove_thought(messages[-1]["content"])
+
         prompt_ids, response_ids = super().encode_oneturn(tokenizer, messages, system, tools)
         if (
             self.thought_words[0] not in messages[-1]["content"]
             and self.thought_words[1] not in messages[-1]["content"]
-        ):
-            if not self.enable_thinking:
-                prompt_ids = prompt_ids + self.get_thought_word_ids(tokenizer)
-            else:
+        ):  # add empty cot
+            if not self.enable_thinking:  # do not compute loss
+                prompt_ids += self.get_thought_word_ids(tokenizer)
+            else:  # do compute loss
                 response_ids = self.get_thought_word_ids(tokenizer) + response_ids

         return prompt_ids, response_ids
@@ -431,15 +434,20 @@ class ReasoningTemplate(Template):
         system: Optional[str] = None,
         tools: Optional[str] = None,
     ) -> list[tuple[list[int], list[int]]]:
+        messages = deepcopy(messages)
+        if self.enable_thinking is False:  # remove all cot
+            for i in range(1, len(messages), 2):
+                messages[i]["content"] = self.remove_thought(messages[i]["content"])
+
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         for i in range(0, len(messages), 2):
             if (
                 self.thought_words[0] not in messages[i + 1]["content"]
                 and self.thought_words[1] not in messages[i + 1]["content"]
-            ):
-                if not self.enable_thinking:
+            ):  # add empty cot
+                if not self.enable_thinking:  # do not compute loss
                     encoded_messages[i] += self.get_thought_word_ids(tokenizer)
-                else:
+                else:  # do compute loss
                     encoded_messages[i + 1] = self.get_thought_word_ids(tokenizer) + encoded_messages[i + 1]

         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
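Both hunks lean on two `Template` helpers not shown in this diff: `remove_thought`, which strips an existing CoT block, and `get_thought_word_ids`, which tokenizes an empty one. A hedged sketch of what they plausibly do, assuming `thought_words` is `("<think>", "</think>")` as the qwen3 tests below suggest:

```python
# Hedged sketch only; the actual methods on Template may differ.
import re

THOUGHT_WORDS = ("<think>", "</think>")  # assumption, matching the qwen3 tests

def remove_thought(content: str) -> str:
    # drop one CoT block plus the surrounding blank lines
    pattern = re.escape(THOUGHT_WORDS[0]) + r".*?" + re.escape(THOUGHT_WORDS[1])
    return re.sub(pattern, "", content, count=1, flags=re.DOTALL).lstrip("\n")

def get_thought_word_ids(tokenizer):
    # token ids of an empty CoT block, e.g. "<think>\n\n</think>\n\n"
    return tokenizer.encode(f"{THOUGHT_WORDS[0]}\n\n{THOUGHT_WORDS[1]}\n\n", add_special_tokens=False)
```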
@@ -463,7 +471,7 @@ def register_template(
     efficient_eos: bool = False,
     replace_eos: bool = False,
     replace_jinja_template: bool = False,
-    enable_thinking: bool = True,
+    enable_thinking: Optional[bool] = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
 ) -> None:
@@ -2566,6 +2566,14 @@ register_model_group(
             DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B",
             DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B",
         },
+        "Qwen2.5-Omni-7B-GPTQ-Int4": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-GPTQ-Int4",
+        },
+        "Qwen2.5-Omni-7B-AWQ": {
+            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Omni-7B-AWQ",
+            DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Omni-7B-AWQ",
+        },
     },
     template="qwen2_omni",
     multimodal=True,
@@ -94,7 +94,9 @@ def check_version(requirement: str, mandatory: bool = False) -> None:

 def check_dependencies() -> None:
     r"""Check the version of the required packages."""
-    check_version("transformers>=4.45.0,<=4.51.3,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version(
+        "transformers>=4.45.0,<=4.52.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0,!=4.52.0"
+    )
     check_version("datasets>=2.16.0,<=3.6.0")
     check_version("accelerate>=0.34.0,<=1.7.0")
     check_version("peft>=0.14.0,<=0.15.2")
@@ -119,7 +119,7 @@ class DataArguments:
         default=None,
         metadata={"help": "Override the default system message in the template."},
     )
-    enable_thinking: bool = field(
+    enable_thinking: Optional[bool] = field(
         default=True,
         metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
     )
@@ -235,10 +235,6 @@ class ProcessorArguments:
         default=False,
         metadata={"help": "Whether to crop the image to patches for internvl."},
     )
-    use_audio_in_video: bool = field(
-        default=False,
-        metadata={"help": "Whether or not to use audio in video inputs."},
-    )
     video_max_pixels: int = field(
         default=256 * 256,
         metadata={"help": "The maximum number of pixels of video inputs."},
@@ -255,6 +251,10 @@ class ProcessorArguments:
         default=128,
         metadata={"help": "The maximum number of sampled frames for video inputs."},
     )
+    use_audio_in_video: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to use audio in video inputs."},
+    )
     audio_sampling_rate: int = field(
         default=16000,
         metadata={"help": "The sampling rate of audio inputs."},
@@ -24,6 +24,7 @@ import transformers.models
 from transformers.activations import ACT2FN

 from ...extras import logging
+from ...extras.packages import is_transformers_version_greater_than


 if TYPE_CHECKING:
@@ -281,7 +282,7 @@ _register_composite_model(
     model_type="qwen2_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
@@ -290,6 +291,6 @@ _register_composite_model(
     model_type="qwen2_5_vl",
     projector_key="visual.merger",
     vision_model_keys=["visual.patch_embed", "visual.blocks"],
-    language_model_keys=["model", "lm_head"],
+    language_model_keys=["language_model"] if is_transformers_version_greater_than("4.52.0") else ["model", "lm_head"],
     lora_conflict_keys=["patch_embed"],
 )
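The conditional reflects the transformers 4.52 refactor in which Qwen2-VL's text stack is exposed as a `language_model` submodule instead of top-level `model`/`lm_head` attributes (the test_visual hunks at the end of this commit exercise both layouts). An illustrative helper, not repository code, for resolving such dotted key paths against either layout:

```python
def resolve_module(model, key: str):
    # walk a dotted path such as "visual.merger" or "language_model"
    module = model
    for attr in key.split("."):
        module = getattr(module, attr)
    return module
```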
@@ -85,8 +85,8 @@ def patch_processor(
     setattr(processor, "video_min_pixels", model_args.video_min_pixels)
     setattr(processor, "video_fps", model_args.video_fps)
     setattr(processor, "video_maxlen", model_args.video_maxlen)
-    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)
     setattr(processor, "use_audio_in_video", model_args.use_audio_in_video)
+    setattr(processor, "audio_sampling_rate", model_args.audio_sampling_rate)


 def patch_config(
@@ -121,11 +121,11 @@ class CustomDPOTrainer(DPOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
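The same one-line change is applied to the KTO, PT, RM, and SFT trainers below. The motivation appears to be that transformers 4.52 extended `Trainer._get_train_sampler` to receive the dataset as an argument; forwarding `*args, **kwargs` keeps the overrides working on both sides of that signature change. The pattern in isolation, as a sketch with a hypothetical class and a simple flag standing in for the finetuning argument:

```python
from typing import Optional

import torch
from transformers import Trainer

class ShufflingAwareTrainer(Trainer):  # hypothetical example class
    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
        if getattr(self.args, "disable_shuffling", False):  # assumption: a flag on args
            return torch.utils.data.SequentialSampler(self.train_dataset)
        return super()._get_train_sampler(*args, **kwargs)  # works pre- and post-4.52
```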
@@ -34,7 +34,6 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, ge

 if TYPE_CHECKING:
-    import torch.utils.data
     from transformers import PreTrainedModel, ProcessorMixin

     from ...hparams import FinetuningArguments
@@ -119,12 +118,12 @@ class CustomKTOTrainer(KTOTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         r"""Replace the sequential sampler of KTO Trainer created by trl with the random sampler."""
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return Trainer._get_train_sampler(self)
+        return Trainer._get_train_sampler(self, *args, **kwargs)

     @override
     def get_batch_samples(self, *args, **kwargs):
@@ -70,11 +70,11 @@ class CustomTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
@@ -78,11 +78,11 @@ class PairwiseTrainer(Trainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(
@@ -92,11 +92,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         return super().create_scheduler(num_training_steps, optimizer)

     @override
-    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+    def _get_train_sampler(self, *args, **kwargs) -> Optional["torch.utils.data.Sampler"]:
         if self.finetuning_args.disable_shuffling:
             return torch.utils.data.SequentialSampler(self.train_dataset)

-        return super()._get_train_sampler()
+        return super()._get_train_sampler(*args, **kwargs)

     @override
     def compute_loss(self, model, inputs, *args, **kwargs):
@@ -205,6 +205,14 @@ def load_eval_results(path: os.PathLike) -> str:
     return f"```json\n{result}\n```\n"


+def calculate_pixels(pixels: str) -> int:
+    r"""Calculate the number of pixels from the expression."""
+    if "*" in pixels:
+        return int(pixels.split("*")[0]) * int(pixels.split("*")[1])
+    else:
+        return int(pixels)
+
+
 def create_ds_config() -> None:
     r"""Create deepspeed config in the current directory."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
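Usage is straightforward: the WebUI textboxes added below hold expressions such as `768*768`, which `calculate_pixels` turns into plain integers before they reach the training arguments, for example:

```python
assert calculate_pixels("768*768") == 589824
assert calculate_pixels("1024") == 1024
```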
@@ -106,11 +106,11 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
                 use_llama_pro = gr.Checkbox()

             with gr.Column():
+                enable_thinking = gr.Checkbox(value=True)
                 report_to = gr.Dropdown(
-                    choices=["none", "all", "wandb", "mlflow", "neptune", "tensorboard"],
-                    value=["none"],
+                    choices=["none", "wandb", "mlflow", "neptune", "tensorboard", "all"],
+                    value="none",
                     allow_custom_value=True,
                     multiselect=True,
                 )

     input_elems.update(
@@ -126,6 +126,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             mask_history,
             resize_vocab,
             use_llama_pro,
+            enable_thinking,
             report_to,
         }
     )
@@ -143,6 +144,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
             mask_history=mask_history,
             resize_vocab=resize_vocab,
             use_llama_pro=use_llama_pro,
+            enable_thinking=enable_thinking,
             report_to=report_to,
         )
     )
@@ -231,6 +233,42 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
         )
     )

+    with gr.Accordion(open=False) as mm_tab:
+        with gr.Row():
+            freeze_vision_tower = gr.Checkbox(value=True)
+            freeze_multi_modal_projector = gr.Checkbox(value=True)
+            freeze_language_model = gr.Checkbox(value=False)
+
+        with gr.Row():
+            image_max_pixels = gr.Textbox(value="768*768")
+            image_min_pixels = gr.Textbox(value="32*32")
+            video_max_pixels = gr.Textbox(value="256*256")
+            video_min_pixels = gr.Textbox(value="16*16")
+
+    input_elems.update(
+        {
+            freeze_vision_tower,
+            freeze_multi_modal_projector,
+            freeze_language_model,
+            image_max_pixels,
+            image_min_pixels,
+            video_max_pixels,
+            video_min_pixels,
+        }
+    )
+    elem_dict.update(
+        dict(
+            mm_tab=mm_tab,
+            freeze_vision_tower=freeze_vision_tower,
+            freeze_multi_modal_projector=freeze_multi_modal_projector,
+            freeze_language_model=freeze_language_model,
+            image_max_pixels=image_max_pixels,
+            image_min_pixels=image_min_pixels,
+            video_max_pixels=video_max_pixels,
+            video_min_pixels=video_min_pixels,
+        )
+    )
+
     with gr.Accordion(open=False) as galore_tab:
         with gr.Row():
             use_galore = gr.Checkbox()
@@ -871,6 +871,28 @@ LOCALES = {
             "info": "拡張ブロックのパラメータのみをトレーニングします。",
         },
     },
+    "enable_thinking": {
+        "en": {
+            "label": "Enable thinking",
+            "info": "Whether or not to enable thinking mode for reasoning models.",
+        },
+        "ru": {
+            "label": "Включить мысли",
+            "info": "Включить режим мысли для моделей решающего характера.",
+        },
+        "zh": {
+            "label": "启用思考模式",
+            "info": "是否启用推理模型的思考模式。",
+        },
+        "ko": {
+            "label": "생각 모드 활성화",
+            "info": "추론 모델의 생각 모드를 활성화할지 여부.",
+        },
+        "ja": {
+            "label": "思考モードを有効化",
+            "info": "推論モデルの思考モードを有効にするかどうか。",
+        },
+    },
     "report_to": {
         "en": {
             "label": "Enable external logger",
@@ -1374,6 +1396,177 @@ LOCALES = {
             "info": "PPO トレーニングにおいて報酬スコアをホワイトニング処理します。",
         },
     },
+    "mm_tab": {
+        "en": {
+            "label": "Multimodal configurations",
+        },
+        "ru": {
+            "label": "Конфигурации мультимедиа",
+        },
+        "zh": {
+            "label": "多模态参数设置",
+        },
+        "ko": {
+            "label": "멀티모달 구성",
+        },
+        "ja": {
+            "label": "多モーダル設定",
+        },
+    },
+    "freeze_vision_tower": {
+        "en": {
+            "label": "Freeze vision tower",
+            "info": "Freeze the vision tower in the model.",
+        },
+        "ru": {
+            "label": "Заморозить башню визиона",
+            "info": "Заморозить башню визиона в модели.",
+        },
+        "zh": {
+            "label": "冻结视觉编码器",
+            "info": "冻结模型中的视觉编码器。",
+        },
+        "ko": {
+            "label": "비전 타워 고정",
+            "info": "모델의 비전 타워를 고정합니다.",
+        },
+        "ja": {
+            "label": "ビジョンタワーの固定",
+            "info": "モデルのビジョンタワーを固定します。",
+        },
+    },
+    "freeze_multi_modal_projector": {
+        "en": {
+            "label": "Freeze multi-modal projector",
+            "info": "Freeze the multi-modal projector in the model.",
+        },
+        "ru": {
+            "label": "Заморозить мультимодальный проектор",
+            "info": "Заморозить мультимодальный проектор в модели.",
+        },
+        "zh": {
+            "label": "冻结多模态投影器",
+            "info": "冻结模型中的多模态投影器。",
+        },
+        "ko": {
+            "label": "멀티모달 프로젝터 고정",
+            "info": "모델의 멀티모달 프로젝터를 고정합니다.",
+        },
+        "ja": {
+            "label": "多モーダルプロジェクターの固定",
+            "info": "モデルの多モーダルプロジェクターを固定します。",
+        },
+    },
+    "freeze_language_model": {
+        "en": {
+            "label": "Freeze language model",
+            "info": "Freeze the language model in the model.",
+        },
+        "ru": {
+            "label": "Заморозить язык модели",
+            "info": "Заморозить язык модели в модели.",
+        },
+        "zh": {
+            "label": "冻结语言模型",
+            "info": "冻结模型中的语言模型。",
+        },
+        "ko": {
+            "label": "언어 모델 고정",
+            "info": "모델의 언어 모델을 고정합니다.",
+        },
+        "ja": {
+            "label": "言語モデルの固定",
+            "info": "モデルの言語モデルを固定します。",
+        },
+    },
+    "image_max_pixels": {
+        "en": {
+            "label": "Image max pixels",
+            "info": "The maximum number of pixels of image inputs.",
+        },
+        "ru": {
+            "label": "Максимальное количество пикселей изображения",
+            "info": "Максимальное количество пикселей изображения.",
+        },
+        "zh": {
+            "label": "图像最大像素",
+            "info": "输入图像的最大像素数。",
+        },
+        "ko": {
+            "label": "이미지 최대 픽셀",
+            "info": "이미지 입력의 최대 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "画像最大ピクセル",
+            "info": "画像入力の最大ピクセル数です。",
+        },
+    },
+    "image_min_pixels": {
+        "en": {
+            "label": "Image min pixels",
+            "info": "The minimum number of pixels of image inputs.",
+        },
+        "ru": {
+            "label": "Минимальное количество пикселей изображения",
+            "info": "Минимальное количество пикселей изображения.",
+        },
+        "zh": {
+            "label": "图像最小像素",
+            "info": "输入图像的最小像素数。",
+        },
+        "ko": {
+            "label": "이미지 최소 픽셀",
+            "info": "이미지 입력의 최소 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "画像最小ピクセル",
+            "info": "画像入力の最小ピクセル数です。",
+        },
+    },
+    "video_max_pixels": {
+        "en": {
+            "label": "Video max pixels",
+            "info": "The maximum number of pixels of video inputs.",
+        },
+        "ru": {
+            "label": "Максимальное количество пикселей видео",
+            "info": "Максимальное количество пикселей видео.",
+        },
+        "zh": {
+            "label": "视频最大像素",
+            "info": "输入视频的最大像素数。",
+        },
+        "ko": {
+            "label": "비디오 최대 픽셀",
+            "info": "비디오 입력의 최대 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "ビデオ最大ピクセル",
+            "info": "ビデオ入力の最大ピクセル数です。",
+        },
+    },
+    "video_min_pixels": {
+        "en": {
+            "label": "Video min pixels",
+            "info": "The minimum number of pixels of video inputs.",
+        },
+        "ru": {
+            "label": "Минимальное количество пикселей видео",
+            "info": "Минимальное количество пикселей видео.",
+        },
+        "zh": {
+            "label": "视频最小像素",
+            "info": "输入视频的最小像素数。",
+        },
+        "ko": {
+            "label": "비디오 최소 픽셀",
+            "info": "비디오 입력의 최소 픽셀 수입니다.",
+        },
+        "ja": {
+            "label": "ビデオ最小ピクセル",
+            "info": "ビデオ入力の最小ピクセル数です。",
+        },
+    },
     "galore_tab": {
         "en": {
             "label": "GaLore configurations",
@@ -2468,23 +2661,6 @@ LOCALES = {
             "label": "HTML タグをエスケープ",
         },
     },
-    "enable_thinking": {
-        "en": {
-            "label": "Enable thinking",
-        },
-        "ru": {
-            "label": "Включить мышление",
-        },
-        "zh": {
-            "label": "启用思考",
-        },
-        "ko": {
-            "label": "사고를 활성화하다",
-        },
-        "ja": {
-            "label": "思考を可能にする",
-        },
-    },
     "clear_btn": {
         "en": {
             "value": "Clear history",
@@ -29,6 +29,7 @@ from .common import (
     DEFAULT_CACHE_DIR,
     DEFAULT_CONFIG_DIR,
     abort_process,
+    calculate_pixels,
     gen_cmd,
     get_save_dir,
     load_args,
@@ -162,7 +163,15 @@ class Runner:
             mask_history=get("train.mask_history"),
             resize_vocab=get("train.resize_vocab"),
             use_llama_pro=get("train.use_llama_pro"),
+            enable_thinking=get("train.enable_thinking"),
             report_to=get("train.report_to"),
+            freeze_vision_tower=get("train.freeze_vision_tower"),
+            freeze_multi_modal_projector=get("train.freeze_multi_modal_projector"),
+            freeze_language_model=get("train.freeze_language_model"),
+            image_max_pixels=calculate_pixels(get("train.image_max_pixels")),
+            image_min_pixels=calculate_pixels(get("train.image_min_pixels")),
+            video_max_pixels=calculate_pixels(get("train.video_max_pixels")),
+            video_min_pixels=calculate_pixels(get("train.video_min_pixels")),
             use_galore=get("train.use_galore"),
             use_apollo=get("train.use_apollo"),
             use_badam=get("train.use_badam"),
@@ -256,12 +265,6 @@ class Runner:
             args["badam_switch_interval"] = get("train.badam_switch_interval")
             args["badam_update_ratio"] = get("train.badam_update_ratio")

-        # report_to
-        if "none" in args["report_to"]:
-            args["report_to"] = "none"
-        elif "all" in args["report_to"]:
-            args["report_to"] = "all"
-
         # swanlab config
         if get("train.use_swanlab"):
            args["swanlab_project"] = get("train.swanlab_project")
@@ -135,8 +135,7 @@ def _check_plugin(
     expected_mm_inputs: dict[str, Any] = {},
     expected_no_mm_inputs: dict[str, Any] = {},
 ) -> None:
-    # test omni_messages
-    if plugin.__class__.__name__ == "Qwen2OmniPlugin":
+    if plugin.__class__.__name__ == "Qwen2OmniPlugin":  # test omni_messages
         assert plugin.process_messages(OMNI_MESSAGES, IMAGES, NO_VIDEOS, AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, AUDIOS, tokenizer, processor) == (
             expected_input_ids,
@@ -146,8 +145,7 @@ def _check_plugin(
             plugin.get_mm_inputs(IMAGES, NO_VIDEOS, AUDIOS, IMGLENS, NO_VIDLENS, AUDLENS, BATCH_IDS, processor),
             expected_mm_inputs,
         )
-    # test mm_messages
-    if plugin.__class__.__name__ != "BasePlugin":
+    elif plugin.__class__.__name__ != "BasePlugin":  # test mm_messages
         assert plugin.process_messages(MM_MESSAGES, IMAGES, NO_VIDEOS, NO_AUDIOS, processor) == expected_mm_messages
         assert plugin.process_token_ids(INPUT_IDS, LABELS, IMAGES, NO_VIDEOS, NO_AUDIOS, tokenizer, processor) == (
             expected_input_ids,
@@ -201,7 +199,7 @@ def test_gemma3_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_internvl_plugin():
     image_seqlen = 256
     tokenizer_module = _load_tokenizer_module(model_name_or_path="OpenGVLab/InternVL3-1B-hf")
@@ -219,7 +217,7 @@ def test_internvl_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.51.0"), reason="Requires transformers>=4.51.0")
 def test_llama4_plugin():
     tokenizer_module = _load_tokenizer_module(model_name_or_path=TINY_LLAMA4)
     processor = tokenizer_module["processor"]
@@ -321,10 +319,9 @@ def test_pixtral_plugin():
     _check_plugin(**check_inputs)


-@pytest.mark.xfail(reason="Unknown error.")
+@pytest.mark.skipif(not is_transformers_version_greater_than("4.52.0"), reason="Requires transformers>=4.52.0")
 def test_qwen2_omni_plugin():
-    image_seqlen = 4
-    audio_seqlen = 2
+    image_seqlen, audio_seqlen = 4, 2
     tokenizer_module = _load_tokenizer_module(model_name_or_path="Qwen/Qwen2.5-Omni-7B")
     qwen2_omni_plugin = get_mm_plugin(
         name="qwen2_omni", audio_token="<|AUDIO|>", image_token="<|IMAGE|>", video_token="<|VIDEO|>"
@@ -127,20 +127,21 @@ def test_encode_multiturn(use_fast: bool):

 @pytest.mark.parametrize("use_fast", [True, False])
 @pytest.mark.parametrize("cot_messages", [True, False])
-@pytest.mark.parametrize("enable_thinking", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False, None])
 def test_reasoning_encode_oneturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
-    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
+    input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
-    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, messages)
+    prompt_ids, answer_ids = template.encode_oneturn(tokenizer, input_messages)
+    output_messages = MESSAGES if enable_thinking is False else input_messages
     prompt_str = (
-        f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
         f"{MESSAGES[1]['content']}<|im_end|>\n"
-        f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
+        f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
     )
-    answer_str = f"{messages[3]['content']}<|im_end|>\n"
-    if not cot_messages:
+    answer_str = f"{output_messages[3]['content']}<|im_end|>\n"
+    if not cot_messages or enable_thinking is False:
         if enable_thinking:
             answer_str = "<think>\n\n</think>\n\n" + answer_str
         else:
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize("use_fast", [True, False])
 | 
			
		||||
@pytest.mark.parametrize("cot_messages", [True, False])
 | 
			
		||||
@pytest.mark.parametrize("enable_thinking", [True, False])
 | 
			
		||||
@pytest.mark.parametrize("enable_thinking", [True, False, None])
 | 
			
		||||
def test_reasoning_encode_multiturn(use_fast: bool, cot_messages: bool, enable_thinking: bool):
 | 
			
		||||
    messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
 | 
			
		||||
    input_messages = MESSAGES_WITH_THOUGHT if cot_messages else MESSAGES
 | 
			
		||||
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", use_fast=use_fast)
 | 
			
		||||
    data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
 | 
			
		||||
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
 | 
			
		||||
    encoded_pairs = template.encode_multiturn(tokenizer, messages)
 | 
			
		||||
    prompt_str_1 = f"<|im_start|>user\n{messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
    answer_str_1 = f"{messages[1]['content']}<|im_end|>\n"
 | 
			
		||||
    prompt_str_2 = f"<|im_start|>user\n{messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
    answer_str_2 = f"{messages[3]['content']}<|im_end|>\n"
 | 
			
		||||
    if not cot_messages:
 | 
			
		||||
    encoded_pairs = template.encode_multiturn(tokenizer, input_messages)
 | 
			
		||||
    output_messages = MESSAGES if enable_thinking is False else input_messages
 | 
			
		||||
    prompt_str_1 = f"<|im_start|>user\n{output_messages[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
    answer_str_1 = f"{output_messages[1]['content']}<|im_end|>\n"
 | 
			
		||||
    prompt_str_2 = f"<|im_start|>user\n{output_messages[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
 | 
			
		||||
    answer_str_2 = f"{output_messages[3]['content']}<|im_end|>\n"
 | 
			
		||||
    if not cot_messages or enable_thinking is False:
 | 
			
		||||
        if enable_thinking:
 | 
			
		||||
            answer_str_1 = "<think>\n\n</think>\n\n" + answer_str_1
 | 
			
		||||
            answer_str_2 = "<think>\n\n</think>\n\n" + answer_str_2
 | 
			
		||||
 | 
			
		||||
@@ -16,6 +16,7 @@ import pytest
 import torch
 from transformers import AutoConfig, AutoModelForVision2Seq

+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.hparams import FinetuningArguments, ModelArguments
 from llamafactory.model.adapter import init_adapter

@@ -45,10 +46,12 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
             assert param.requires_grad != freeze_language_model


-@pytest.mark.parametrize("freeze_vision_tower", (False, True))
-def test_visual_lora(freeze_vision_tower: bool):
+@pytest.mark.parametrize("freeze_vision_tower,freeze_language_model", ((False, False), (False, True), (True, False)))
+def test_visual_lora(freeze_vision_tower: bool, freeze_language_model: bool):
     model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
-    finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)
+    finetuning_args = FinetuningArguments(
+        finetuning_type="lora", freeze_vision_tower=freeze_vision_tower, freeze_language_model=freeze_language_model
+    )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
         model = AutoModelForVision2Seq.from_config(config)
@@ -61,10 +64,15 @@
         else:
             frozen_params.add(name)

-    if freeze_vision_tower:
-        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" not in trainable_params
+    if is_transformers_version_greater_than("4.52.0"):
+        visual_param_name = "base_model.model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
+        language_param_name = "base_model.model.model.language_model.layers.0.self_attn.q_proj.lora_A.default.weight"
+        merger_param_name = "base_model.model.model.visual.merger.lora_A.default.weight"
     else:
-        assert "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight" in trainable_params
+        visual_param_name = "base_model.model.visual.blocks.0.attn.qkv.lora_A.default.weight"
+        language_param_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
+        merger_param_name = "base_model.model.visual.merger.lora_A.default.weight"

-    assert "merger" not in trainable_params
-    assert "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight" in trainable_params
+    assert (visual_param_name in trainable_params) != freeze_vision_tower
+    assert (language_param_name in trainable_params) != freeze_language_model
+    assert (merger_param_name in trainable_params) is False
@@ -1,2 +1,2 @@
 # change if test fails or cache is outdated
-0.9.3.106
+0.9.3.107