Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
[breaking] support transformers 4.48 (#6628)

Former-commit-id: f154ab175c513a4d7bb866bf2cffc34b77b50508
parent: e71737351f
commit: 222423bcef
.github/workflows/tests.yml (vendored, 2 changed lines)
@@ -22,10 +22,10 @@ jobs:
       fail-fast: false
       matrix:
         python-version:
-          - "3.8"  # TODO: remove py38 in next transformers release
           - "3.9"
           - "3.10"
           - "3.11"
+          - "3.12"
         os:
           - "ubuntu-latest"
           - "windows-latest"
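The matrix above now covers Python 3.9 through 3.12 on both runners, and with fail-fast disabled a failing combination does not cancel the others. As a quick illustrative sketch (plain Python, standard library only, not part of the commit), the matrix expands to the cross product of the two axes:

    # Sketch: expansion of the updated CI matrix (4 Python versions x 2 runners = 8 jobs).
    from itertools import product

    python_versions = ["3.9", "3.10", "3.11", "3.12"]
    runners = ["ubuntu-latest", "windows-latest"]

    for py, runner in product(python_versions, runners):
        print(f"test job: python {py} on {runner}")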
README.md (12 changed lines)
@@ -377,11 +377,11 @@ huggingface-cli login

 | Mandatory    | Minimum | Recommend |
 | ------------ | ------- | --------- |
-| python       | 3.8     | 3.11      |
+| python       | 3.9     | 3.10      |
 | torch        | 1.13.1  | 2.4.0     |
-| transformers | 4.41.2  | 4.43.4    |
-| datasets     | 2.16.0  | 2.20.0    |
-| accelerate   | 0.30.1  | 0.32.0    |
+| transformers | 4.41.2  | 4.45.2    |
+| datasets     | 2.16.0  | 3.2.0     |
+| accelerate   | 0.34.0  | 1.2.1     |
 | peft         | 0.11.1  | 0.12.0    |
 | trl          | 0.8.6   | 0.9.6     |

@@ -390,8 +390,8 @@ huggingface-cli login
 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.5.0     |
-| flash-attn   | 2.3.0   | 2.6.3     |
+| vllm         | 0.4.3   | 0.6.6     |
+| flash-attn   | 2.3.0   | 2.7.2     |

 ### Hardware Requirement

README_zh.md (12 changed lines)
@@ -379,11 +379,11 @@ huggingface-cli login

 | 必需项       | 至少     | 推荐      |
 | ------------ | ------- | --------- |
-| python       | 3.8     | 3.11      |
+| python       | 3.9     | 3.10      |
 | torch        | 1.13.1  | 2.4.0     |
-| transformers | 4.41.2  | 4.43.4    |
-| datasets     | 2.16.0  | 2.20.0    |
-| accelerate   | 0.30.1  | 0.32.0    |
+| transformers | 4.41.2  | 4.45.2    |
+| datasets     | 2.16.0  | 3.2.0     |
+| accelerate   | 0.34.0  | 1.2.1     |
 | peft         | 0.11.1  | 0.12.0    |
 | trl          | 0.8.6   | 0.9.6     |

@@ -392,8 +392,8 @@ huggingface-cli login
 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.5.0     |
-| flash-attn   | 2.3.0   | 2.6.3     |
+| vllm         | 0.4.3   | 0.6.6     |
+| flash-attn   | 2.3.0   | 2.7.2     |

 ### 硬件依赖

@@ -1,9 +1,10 @@
-transformers>=4.41.2,<=4.46.1
-datasets>=2.16.0,<=3.1.0
-accelerate>=0.34.0,<=1.0.1
+transformers>=4.41.2,<=4.45.2;python_version<'3.10'
+transformers>=4.41.2,<=4.48.1,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
+datasets>=2.16.0,<=3.2.0
+accelerate>=0.34.0,<=1.2.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
-tokenizers>=0.19.0,<0.20.4
+tokenizers>=0.19.0,<=0.21.0
 gradio>=4.38.0,<=5.12.0
 pandas>=2.0.0
 scipy
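The hunk above (evidently the project's requirements file) splits the transformers pin by interpreter version using PEP 508 environment markers: interpreters older than 3.10 stay on <=4.45.2, while 3.10 and newer may go up to 4.48.1 minus the excluded 4.46.x, 4.47.x and 4.48.0 releases. A small sketch of how such markers evaluate, using the packaging library (assumed to be installed, as it is wherever pip and transformers are):

    # Sketch: evaluating the PEP 508 environment markers used in the new requirement lines.
    from packaging.markers import Marker

    legacy = Marker("python_version < '3.10'")   # selects transformers<=4.45.2
    modern = Marker("python_version >= '3.10'")  # selects transformers<=4.48.1 (with exclusions)

    for py in ("3.9", "3.10", "3.12"):
        env = {"python_version": py}
        which = "legacy pin" if legacy.evaluate(env) else "modern pin"
        print(f"python {py}: {which}")  # 3.9 -> legacy pin, 3.10/3.12 -> modern pin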
setup.py (6 changed lines)
@@ -46,7 +46,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
     "liger-kernel": ["liger-kernel"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
@@ -92,7 +92,7 @@ def main():
         url="https://github.com/hiyouga/LLaMA-Factory",
         package_dir={"": "src"},
         packages=find_packages("src"),
-        python_requires=">=3.8.0",
+        python_requires=">=3.9.0",
         install_requires=get_requires(),
         extras_require=extra_require,
         entry_points={"console_scripts": get_console_scripts()},
@@ -104,10 +104,10 @@ def main():
             "License :: OSI Approved :: Apache Software License",
             "Operating System :: OS Independent",
             "Programming Language :: Python :: 3",
-            "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",
             "Programming Language :: Python :: 3.10",
             "Programming Language :: Python :: 3.11",
+            "Programming Language :: Python :: 3.12",
             "Topic :: Scientific/Engineering :: Artificial Intelligence",
         ],
     )
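With python_requires raised to >=3.9.0, pip will refuse to install new releases on Python 3.8, and the floor is recorded in the package metadata. A small sketch of inspecting that field from an installed copy (assumes the distribution is installed under its project name, llamafactory):

    # Sketch: reading the Requires-Python floor from installed package metadata.
    from importlib.metadata import PackageNotFoundError, metadata

    try:
        meta = metadata("llamafactory")
        print("Requires-Python:", meta["Requires-Python"])  # expected: >=3.9.0 after this commit
    except PackageNotFoundError:
        print("llamafactory is not installed in this environment")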
@@ -20,17 +20,17 @@ Level:

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.46.1
-    datasets>=2.16.0,<=3.1.0
-    accelerate>=0.34.0,<=1.0.1
+    transformers>=4.41.2,<=4.48.1,!=4.46.*,!=4.47.*,!=4.48.0
+    datasets>=2.16.0,<=3.2.0
+    accelerate>=0.34.0,<=1.2.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.46.1
+    transformers>=4.41.2,<4.48.0
   packing:
-    transformers>=4.43.0,<=4.46.1
+    transformers>=4.43.0,<=4.48.1

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
@@ -34,6 +34,7 @@ from transformers.utils import (
 from transformers.utils.versions import require_version

 from . import logging
+from .packages import is_transformers_version_greater_than


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -93,11 +94,13 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    check_version("transformers>=4.41.2,<=4.46.1")
-    check_version("datasets>=2.16.0,<=3.1.0")
-    check_version("accelerate>=0.34.0,<=1.0.1")
+    check_version("transformers>=4.41.2,<=4.48.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("datasets>=2.16.0,<=3.2.0")
+    check_version("accelerate>=0.34.0,<=1.2.1")
     check_version("peft>=0.11.1,<=0.12.0")
     check_version("trl>=0.8.6,<=0.9.6")
+    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
+        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
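Note that the runtime check spells out each excluded release individually rather than using the `!=4.46.*` wildcards from the requirement pins. As a sanity check, the wildcard form from the requirement line can be evaluated with a PEP 440 SpecifierSet; a small sketch (assuming the packaging library is available):

    # Sketch: which transformers versions the new constraint accepts.
    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=4.41.2,<=4.48.1,!=4.46.*,!=4.47.*,!=4.48.0")

    for candidate in ("4.45.2", "4.46.1", "4.47.1", "4.48.0", "4.48.1"):
        print(candidate, "->", "ok" if spec.contains(candidate) else "rejected")
    # Expected: only 4.45.2 and 4.48.1 pass; the 4.46.x/4.47.x/4.48.0 range is excluded.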
@@ -87,11 +87,6 @@ def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)


-@lru_cache
-def is_transformers_version_equal_to_4_46():
-    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
-
-
 def is_uvicorn_available():
     return _is_package_available("uvicorn")

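With the 4.46-specific helper removed, version gating now goes through is_transformers_version_greater_than alone. A self-contained sketch of that helper pattern, simplified from what the hunk shows (the project's real _get_package_version may handle missing packages differently):

    # Sketch of the cached version-gate helper kept in the packages module.
    from functools import lru_cache
    from importlib.metadata import PackageNotFoundError, version as get_version

    from packaging import version


    def _get_package_version(name: str) -> "version.Version":
        try:
            return version.parse(get_version(name))
        except PackageNotFoundError:
            return version.parse("0.0.0")  # treat a missing package as an old version


    @lru_cache
    def is_transformers_version_greater_than(content: str) -> bool:
        # Used to gate code paths, e.g. is_transformers_version_greater_than("4.48")
        # for the attention refactor handled in the test file below.
        return _get_package_version("transformers") >= version.parse(content)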
@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<4.48.0")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    check_version("transformers>=4.43.0,<=4.46.1")
+    check_version("transformers>=4.43.0,<=4.48.1")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach

@@ -282,19 +282,12 @@ class CustomDPOTrainer(DPOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        Subclass and override to accept extra kwargs.
         """
-        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
+        return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
         r"""
         Log `logs` on the various objects watching training, including stored metrics.
         """
@@ -318,4 +311,4 @@ class CustomDPOTrainer(DPOTrainer):
             if not key.startswith("dummy_"):
                 logs[key] = metric

-        return Trainer.log(self, logs)
+        return Trainer.log(self, logs, *args, **kwargs)
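Two things change in this trainer, and the same pattern repeats in the CustomKTOTrainer hunks below: the 4.46-only loss-rescaling override is dropped, and log now accepts and forwards extra positional and keyword arguments, since newer transformers releases may call Trainer.log with additional parameters. A standalone sketch of why the forwarding signature stays compatible as the base method grows parameters (toy classes only, no transformers import):

    # Toy sketch: forwarding *args/**kwargs keeps an overridden log() working even if
    # the base class later adds parameters. BaseTrainer stands in for transformers.Trainer.
    from typing import Dict


    class BaseTrainer:
        def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
            print("base log:", logs, args, kwargs)


    class CustomTrainerSketch(BaseTrainer):
        def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
            # Mirror of the pattern in the hunk: filter internal keys, then delegate.
            logs = {key: value for key, value in logs.items() if not key.startswith("dummy_")}
            return BaseTrainer.log(self, logs, *args, **kwargs)


    CustomTrainerSketch().log({"loss": 0.42, "dummy_chosen": 1.0}, 12.3)  # extra positional arg passes through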
@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach

@@ -256,19 +256,12 @@ class CustomKTOTrainer(KTOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        Subclass and override to accept extra kwargs.
         """
-        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
+        return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
         r"""
         Log `logs` on the various objects watching training, including stored metrics.
         """
@@ -304,4 +297,4 @@ class CustomKTOTrainer(KTOTrainer):
             if not key.startswith("dummy_"):
                 logs[key] = metric

-        return Trainer.log(self, logs)
+        return Trainer.log(self, logs, *args, **kwargs)
@@ -13,7 +13,7 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 from transformers import Trainer
@@ -25,7 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, ProcessorMixin
+    from transformers import ProcessorMixin

     from ...hparams import FinetuningArguments

@@ -72,21 +72,3 @@ class CustomTrainer(Trainer):
             return torch.utils.data.SequentialSampler(self.train_dataset)

         return super()._get_train_sampler()
-
-    @override
-    def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
-
-        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
-        """
-        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
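These compute_loss overrides existed only to compensate for a loss that came back too large by a factor of gradient_accumulation_steps on the affected transformers releases (see the upstream PR linked in the removed docstring); with 4.46/4.47 excluded and the fix present upstream, the workaround can go. A toy arithmetic sketch of what the deleted branch did (numbers invented for illustration):

    # Toy illustration of the deleted workaround: in the affected configurations the loss
    # reported by the base Trainer was larger by a factor of gradient_accumulation_steps,
    # so the override simply divided it back down before returning it.
    gradient_accumulation_steps = 4
    reported_loss = 20.0                                      # value as returned by super().compute_loss()
    corrected_loss = reported_loss / gradient_accumulation_steps
    print(corrected_loss)  # 5.0 -> the per-step loss the override reported instead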
@@ -25,7 +25,7 @@ from transformers import Trainer
 from typing_extensions import override

 from ...extras import logging
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -107,10 +107,6 @@ class PairwiseTrainer(Trainer):
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1
-
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
         else:
@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments
@@ -88,24 +88,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

         return super()._get_train_sampler()

-    @override
-    def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
-
-        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
-        """
-        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
-
     @override
     def prediction_step(
         self,
@@ -23,7 +23,7 @@ from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
-from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
+from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
     DEFAULT_CONFIG_DIR,
@@ -180,7 +180,7 @@ class Runner:
             plot_loss=True,
             trust_remote_code=True,
             ddp_timeout=180000000,
-            include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True,  # FIXME
+            include_num_input_tokens_seen=True,
         )
         args.update(json.loads(get("train.extra_args")))

@@ -14,8 +14,10 @@

 import os

+import pytest
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_infer_model


@@ -27,6 +29,7 @@ INFER_ARGS = {
 }


+@pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.")
 def test_attention():
     attention_available = ["disabled"]
     if is_torch_sdpa_available():
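The new marker turns the attention test into an expected failure only when the installed transformers is 4.48 or newer; the first positional argument of pytest.mark.xfail is a condition evaluated at collection time. A tiny self-contained sketch of that usage (the flag below is a stand-in for the real version check):

    # Sketch: condition-gated xfail, mirroring the marker added above.
    import pytest

    ATTENTION_REFACTORED = True  # stand-in for is_transformers_version_greater_than("4.48")


    @pytest.mark.xfail(ATTENTION_REFACTORED, reason="Attention refactor.")
    def test_legacy_attention_patch():
        # Fails under the refactor; pytest reports it as xfail instead of a hard failure.
        assert not ATTENTION_REFACTORED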