commit 222423bcef — [breaking] support transformers 4.48 (#6628)
parent e71737351f
Former-commit-id: f154ab175c513a4d7bb866bf2cffc34b77b50508
.github/workflows/tests.yml — 2 changes (vendored)

@@ -22,10 +22,10 @@ jobs:
       fail-fast: false
       matrix:
         python-version:
-          - "3.8" # TODO: remove py38 in next transformers release
           - "3.9"
           - "3.10"
           - "3.11"
+          - "3.12"
         os:
           - "ubuntu-latest"
           - "windows-latest"
README.md — 12 changes

@@ -377,11 +377,11 @@ huggingface-cli login

 | Mandatory    | Minimum | Recommend |
 | ------------ | ------- | --------- |
-| python       | 3.8     | 3.11      |
+| python       | 3.9     | 3.10      |
 | torch        | 1.13.1  | 2.4.0     |
-| transformers | 4.41.2  | 4.43.4    |
-| datasets     | 2.16.0  | 2.20.0    |
-| accelerate   | 0.30.1  | 0.32.0    |
+| transformers | 4.41.2  | 4.45.2    |
+| datasets     | 2.16.0  | 3.2.0     |
+| accelerate   | 0.34.0  | 1.2.1     |
 | peft         | 0.11.1  | 0.12.0    |
 | trl          | 0.8.6   | 0.9.6     |

@@ -390,8 +390,8 @@ huggingface-cli login

 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.5.0     |
-| flash-attn   | 2.3.0   | 2.6.3     |
+| vllm         | 0.4.3   | 0.6.6     |
+| flash-attn   | 2.3.0   | 2.7.2     |

 ### Hardware Requirement
README_zh.md — 12 changes

@@ -379,11 +379,11 @@ huggingface-cli login

 | 必需项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
-| python       | 3.8     | 3.11      |
+| python       | 3.9     | 3.10      |
 | torch        | 1.13.1  | 2.4.0     |
-| transformers | 4.41.2  | 4.43.4    |
-| datasets     | 2.16.0  | 2.20.0    |
-| accelerate   | 0.30.1  | 0.32.0    |
+| transformers | 4.41.2  | 4.45.2    |
+| datasets     | 2.16.0  | 3.2.0     |
+| accelerate   | 0.34.0  | 1.2.1     |
 | peft         | 0.11.1  | 0.12.0    |
 | trl          | 0.8.6   | 0.9.6     |

@@ -392,8 +392,8 @@ huggingface-cli login

 | CUDA         | 11.6    | 12.2      |
 | deepspeed    | 0.10.0  | 0.14.0    |
 | bitsandbytes | 0.39.0  | 0.43.1    |
-| vllm         | 0.4.3   | 0.5.0     |
-| flash-attn   | 2.3.0   | 2.6.3     |
+| vllm         | 0.4.3   | 0.6.6     |
+| flash-attn   | 2.3.0   | 2.7.2     |

 ### 硬件依赖
requirements.txt

@@ -1,9 +1,10 @@
-transformers>=4.41.2,<=4.46.1
-datasets>=2.16.0,<=3.1.0
-accelerate>=0.34.0,<=1.0.1
+transformers>=4.41.2,<=4.45.2;python_version<'3.10'
+transformers>=4.41.2,<=4.48.1,!=4.46.*,!=4.47.*,!=4.48.0;python_version>='3.10'
+datasets>=2.16.0,<=3.2.0
+accelerate>=0.34.0,<=1.2.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
-tokenizers>=0.19.0,<0.20.4
+tokenizers>=0.19.0,<=0.21.0
 gradio>=4.38.0,<=5.12.0
 pandas>=2.0.0
 scipy
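The split transformers pin relies on PEP 508 environment markers: pip evaluates `python_version` at install time, so Python 3.9 environments resolve the conservative `<=4.45.2` pin while 3.10+ environments may take up to 4.48.1 minus the excluded releases. A minimal sketch of how such a requirement behaves, using the `packaging` library (the same machinery pip builds on); the requirement string is copied verbatim from the diff:

```python
from packaging.requirements import Requirement

req = Requirement(
    "transformers>=4.41.2,<=4.48.1,!=4.46.*,!=4.47.*,!=4.48.0; python_version>='3.10'"
)

# The marker is evaluated against the current interpreter by default;
# passing an environment dict lets us test other Python versions.
print(req.marker.evaluate({"python_version": "3.9"}))   # False -> this pin is skipped
print(req.marker.evaluate({"python_version": "3.10"}))  # True  -> this pin applies

# The specifier set shows which releases survive the exclusions.
for v in ["4.45.2", "4.46.1", "4.47.0", "4.48.0", "4.48.1"]:
    print(v, v in req.specifier)  # only 4.45.2 and 4.48.1 pass
```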
setup.py — 6 changes

@@ -46,7 +46,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
     "liger-kernel": ["liger-kernel"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],

@@ -92,7 +92,7 @@ def main():
         url="https://github.com/hiyouga/LLaMA-Factory",
         package_dir={"": "src"},
         packages=find_packages("src"),
-        python_requires=">=3.8.0",
+        python_requires=">=3.9.0",
         install_requires=get_requires(),
         extras_require=extra_require,
         entry_points={"console_scripts": get_console_scripts()},

@@ -104,10 +104,10 @@ def main():
             "License :: OSI Approved :: Apache Software License",
             "Operating System :: OS Independent",
             "Programming Language :: Python :: 3",
-            "Programming Language :: Python :: 3.8",
             "Programming Language :: Python :: 3.9",
             "Programming Language :: Python :: 3.10",
             "Programming Language :: Python :: 3.11",
+            "Programming Language :: Python :: 3.12",
             "Topic :: Scientific/Engineering :: Artificial Intelligence",
         ],
     )
src/llamafactory/__init__.py

@@ -20,17 +20,17 @@ Level:

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.46.1
-    datasets>=2.16.0,<=3.1.0
-    accelerate>=0.34.0,<=1.0.1
+    transformers>=4.41.2,<=4.48.1,!=4.46.*,!=4.47.*,!=4.48.0
+    datasets>=2.16.0,<=3.2.0
+    accelerate>=0.34.0,<=1.2.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.46.1
+    transformers>=4.41.2,<4.48.0
   packing:
-    transformers>=4.43.0,<=4.46.1
+    transformers>=4.43.0,<=4.48.1

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
src/llamafactory/extras/misc.py

@@ -34,6 +34,7 @@ from transformers.utils import (
 from transformers.utils.versions import require_version

 from . import logging
+from .packages import is_transformers_version_greater_than


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()

@@ -93,11 +94,13 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    check_version("transformers>=4.41.2,<=4.46.1")
-    check_version("datasets>=2.16.0,<=3.1.0")
-    check_version("accelerate>=0.34.0,<=1.0.1")
+    check_version("transformers>=4.41.2,<=4.48.1,!=4.46.0,!=4.46.1,!=4.46.2,!=4.46.3,!=4.47.0,!=4.47.1,!=4.48.0")
+    check_version("datasets>=2.16.0,<=3.2.0")
+    check_version("accelerate>=0.34.0,<=1.2.1")
     check_version("peft>=0.11.1,<=0.12.0")
     check_version("trl>=0.8.6,<=0.9.6")
+    if is_transformers_version_greater_than("4.46.0") and not is_transformers_version_greater_than("4.48.1"):
+        logger.warning_rank0_once("There are known bugs in transformers v4.46.0-v4.48.0, please use other versions.")


 def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
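The hunk above combines a hard pin with a soft warning: 4.46.x and 4.47.x are excluded outright by `check_version`, while the `is_transformers_version_greater_than` pair catches anything in the buggy window that slips through (e.g. when version checking is disabled via `DISABLE_VERSION_CHECK=1`). A hedged, standalone sketch of the gate logic using `packaging`, to show which installed versions would trigger the warning:

```python
from packaging.version import parse

def is_greater_than(installed: str, content: str) -> bool:
    # Mirrors the helper's semantics: ">=" comparison, not strict ">".
    return parse(installed) >= parse(content)

for installed in ["4.45.2", "4.46.1", "4.47.1", "4.48.0", "4.48.1"]:
    in_buggy_window = is_greater_than(installed, "4.46.0") and not is_greater_than(installed, "4.48.1")
    print(installed, "warn" if in_buggy_window else "ok")
# 4.46.1, 4.47.1 and 4.48.0 fall inside [4.46.0, 4.48.1) and warn; 4.48.1 is clean.
```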
src/llamafactory/extras/packages.py

@@ -87,11 +87,6 @@ def is_transformers_version_greater_than(content: str):
     return _get_package_version("transformers") >= version.parse(content)


-@lru_cache
-def is_transformers_version_equal_to_4_46():
-    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
-
-
 def is_uvicorn_available():
     return _is_package_available("uvicorn")
src/llamafactory/model/model_utils/longlora.py

@@ -350,7 +350,7 @@ def llama_sdpa_attention_forward(


 def _apply_llama_patch() -> None:
-    check_version("transformers>=4.41.2,<=4.46.1")
+    check_version("transformers>=4.41.2,<4.48.0")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
src/llamafactory/model/model_utils/packing.py

@@ -118,6 +118,6 @@ def configure_packing(model_args: "ModelArguments", is_trainable: bool) -> None:
     if not is_trainable or not model_args.block_diag_attn:
         return

-    check_version("transformers>=4.43.0,<=4.46.1")
+    check_version("transformers>=4.43.0,<=4.48.1")
     transformers.modeling_flash_attention_utils._get_unpad_data = get_unpad_data
     logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
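For context on what the patched `_get_unpad_data` accomplishes: with sequence packing, the attention mask carries an incrementing segment id (1, 2, 3, ...) per packed sub-sequence rather than plain 0/1, and flash-attention's varlen kernels receive cumulative sequence lengths so that attention becomes block-diagonal, i.e. tokens only attend within their own segment. A minimal sketch of that bookkeeping; `unpad_data_sketch` is a hypothetical helper illustrating the idea, not the repo's actual implementation:

```python
import torch

def unpad_data_sketch(attention_mask: torch.Tensor):
    """Derive flash-attn varlen inputs from a packed mask of segment ids (sketch)."""
    flat = attention_mask.flatten()
    indices = torch.nonzero(flat != 0).flatten()  # positions of real (non-pad) tokens
    seqlens = torch.bincount(flat[indices])[1:]   # length of each packed segment
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )                                             # cumulative boundaries, starting at 0
    max_seqlen = int(seqlens.max())
    return indices, cu_seqlens, max_seqlen

# One row packing three sequences of lengths 3, 2, 1, then padding:
mask = torch.tensor([[1, 1, 1, 2, 2, 3, 0, 0]])
print(unpad_data_sketch(mask))
# cu_seqlens [0, 3, 5, 6] makes tokens 0-2, 3-4 and 5 separate attention blocks.
```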
src/llamafactory/train/dpo/trainer.py

@@ -29,7 +29,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach

@@ -282,19 +282,12 @@ class CustomDPOTrainer(DPOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        Subclass and override to accept extra kwargs.
         """
-        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
+        return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
         r"""
         Log `logs` on the various objects watching training, including stored metrics.
         """

@@ -318,4 +311,4 @@ class CustomDPOTrainer(DPOTrainer):
             if not key.startswith("dummy_"):
                 logs[key] = metric

-        return Trainer.log(self, logs)
+        return Trainer.log(self, logs, *args, **kwargs)
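The deleted `compute_loss` bodies here and in the other trainers were a workaround for the gradient-accumulation loss bug in transformers 4.46 (see the PRs linked in the removed docstrings): once the Trainer began passing `num_items_in_batch`, the returned loss was no longer divided by the number of accumulation steps for models that do not accept loss kwargs, so losses came out inflated. With 4.46/4.47 excluded and 4.48.1 fixing the scaling upstream, the manual division can go. A toy, standalone illustration of the normalization issue (not repo code):

```python
# Two micro-batches accumulated into one optimizer step.
micro_losses = [2.0, 4.0]          # per-micro-batch mean losses
grad_accum_steps = len(micro_losses)

# Correct: each micro-batch contributes 1/grad_accum_steps of the step loss.
step_loss = sum(loss / grad_accum_steps for loss in micro_losses)  # 3.0

# Buggy 4.46 behavior (models not accepting loss kwargs): the division was
# skipped, so the logged/backpropagated loss was grad_accum_steps times too large.
buggy_step_loss = sum(micro_losses)                                # 6.0
```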
src/llamafactory/train/kto/trainer.py

@@ -28,7 +28,7 @@ from trl.trainer import disable_dropout_in_model
 from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach

@@ -256,19 +256,12 @@ class CustomKTOTrainer(KTOTrainer):
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
         r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
+        Subclass and override to accept extra kwargs.
         """
-        loss = super().compute_loss(model, inputs, return_outputs)
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
+        return super().compute_loss(model, inputs, return_outputs)

     @override
-    def log(self, logs: Dict[str, float]) -> None:
+    def log(self, logs: Dict[str, float], *args, **kwargs) -> None:
         r"""
         Log `logs` on the various objects watching training, including stored metrics.
         """

@@ -304,4 +297,4 @@ class CustomKTOTrainer(KTOTrainer):
             if not key.startswith("dummy_"):
                 logs[key] = metric

-        return Trainer.log(self, logs)
+        return Trainer.log(self, logs, *args, **kwargs)
src/llamafactory/train/pt/trainer.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from types import MethodType
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 from transformers import Trainer

@@ -25,7 +25,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler


 if TYPE_CHECKING:
-    from transformers import PreTrainedModel, ProcessorMixin
+    from transformers import ProcessorMixin

     from ...hparams import FinetuningArguments

@@ -72,21 +72,3 @@ class CustomTrainer(Trainer):
             return torch.utils.data.SequentialSampler(self.train_dataset)

         return super()._get_train_sampler()
-
-    @override
-    def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
-
-        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
-        """
-        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss
src/llamafactory/train/rm/trainer.py

@@ -25,7 +25,7 @@ from transformers import Trainer
 from typing_extensions import override

 from ...extras import logging
-from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
+from ...extras.packages import is_transformers_version_greater_than
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

@@ -107,10 +107,6 @@ class PairwiseTrainer(Trainer):
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
-
-        if is_transformers_version_equal_to_4_46() and kwargs.get("num_items_in_batch"):
-            loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0-4.46.1
-
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
         else:
src/llamafactory/train/sft/trainer.py

@@ -34,7 +34,7 @@ from ..trainer_utils import create_custom_optimizer, create_custom_scheduler

 if TYPE_CHECKING:
     from torch.utils.data import Dataset
-    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+    from transformers import PreTrainedTokenizer, ProcessorMixin
     from transformers.trainer import PredictionOutput

     from ...hparams import FinetuningArguments

@@ -88,24 +88,6 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):

         return super()._get_train_sampler()
-
-    @override
-    def compute_loss(
-        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
-    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
-        r"""
-        Fixes the loss value. See https://github.com/huggingface/transformers/pull/35438 for details.
-
-        It should be removed after https://github.com/huggingface/transformers/pull/35651 is merged.
-        """
-        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
-        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
-            if return_outputs:
-                loss = (loss[0] / self.args.gradient_accumulation_steps, *loss[1:])
-            else:
-                loss = loss / self.args.gradient_accumulation_steps
-
-        return loss

     @override
     def prediction_step(
         self,
src/llamafactory/webui/runner.py

@@ -23,7 +23,7 @@ from transformers.utils import is_torch_npu_available

 from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
 from ..extras.misc import is_gpu_or_npu_available, torch_gc, use_ray
-from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
+from ..extras.packages import is_gradio_available
 from .common import (
     DEFAULT_CACHE_DIR,
     DEFAULT_CONFIG_DIR,

@@ -180,7 +180,7 @@ class Runner:
             plot_loss=True,
             trust_remote_code=True,
             ddp_timeout=180000000,
-            include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True,  # FIXME
+            include_num_input_tokens_seen=True,
         )
         args.update(json.loads(get("train.extra_args")))
tests/model/model_utils/test_attention.py

@@ -14,8 +14,10 @@

 import os

+import pytest
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

+from llamafactory.extras.packages import is_transformers_version_greater_than
 from llamafactory.train.test_utils import load_infer_model


@@ -27,6 +29,7 @@ INFER_ARGS = {
 }


+@pytest.mark.xfail(is_transformers_version_greater_than("4.48"), reason="Attention refactor.")
 def test_attention():
     attention_available = ["disabled"]
     if is_torch_sdpa_available():
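The conditional `xfail` marks the test as expected-to-fail only when the installed transformers already contains the 4.48 attention refactor, so CI stays green across the whole support matrix instead of hard-failing on new versions. This is standard pytest behavior; a minimal standalone illustration with a hypothetical condition:

```python
import sys

import pytest

# The condition is evaluated at collection time; when True, a failure is
# reported as "xfail" rather than breaking the run (a pass shows as "xpass").
@pytest.mark.xfail(sys.version_info < (3, 10), reason="Relies on 3.10+ behavior.")
def test_feature():
    assert sys.version_info >= (3, 10)
```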