From 8185eb1890286515038aa629597bdc5d0b4a4355 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 30 Oct 2024 08:56:46 +0000
Subject: [PATCH] fix incorrect loss value for vlms

Former-commit-id: 0aa29a71ce958343a2086090d647eb63b8f5f5be
---
 requirements.txt                               |  4 ++--
 src/llamafactory/__init__.py                   |  8 ++++----
 src/llamafactory/extras/misc.py                |  8 ++++----
 src/llamafactory/extras/packages.py            |  2 +-
 src/llamafactory/model/model_utils/longlora.py |  2 +-
 src/llamafactory/model/model_utils/packing.py  |  2 +-
 src/llamafactory/train/dpo/trainer.py          |  4 ++--
 src/llamafactory/train/kto/trainer.py          |  4 ++--
 src/llamafactory/train/ppo/trainer.py          |  2 +-
 src/llamafactory/train/pt/trainer.py           | 15 ++++++++++++++-
 src/llamafactory/train/rm/trainer.py           |  4 ++--
 src/llamafactory/train/sft/trainer.py          | 15 ++++++++++++++-
 12 files changed, 48 insertions(+), 22 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 126316fe..6d547813 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-transformers>=4.41.2,<=4.46.0
-datasets>=2.16.0,<=2.21.0
+transformers>=4.41.2,<=4.46.1
+datasets>=2.16.0,<=3.0.2
 accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index fe40fd79..42b19b12 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -20,17 +20,17 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.46.0
-    datasets>=2.16.0,<=2.21.0
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.0.2
     accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.46.0
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.46.0
+    transformers>=4.41.2,<=4.46.1
 
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index a44ad8fe..52d43341 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -79,8 +79,8 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
-        require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
+        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+        require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
         require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@@ -237,7 +237,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
 
     if use_modelscope():
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import snapshot_download
+        from modelscope import snapshot_download  # type: ignore
 
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
         return snapshot_download(
@@ -248,7 +248,7 @@
 
     if use_openmind():
         require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
-        from openmind.utils.hub import snapshot_download
+        from openmind.utils.hub import snapshot_download  # type: ignore
 
         return snapshot_download(
             model_args.model_name_or_path,
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 99ddbbe7..98066714 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -81,7 +81,7 @@ def is_transformers_version_greater_than_4_43():
 
 @lru_cache
 def is_transformers_version_equal_to_4_46():
-    return _get_package_version("transformers") == version.parse("4.46.0")
+    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
 
 
 def is_uvicorn_available():
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index 215a8ada..8796b197 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 104fa5a7..0fdb0e06 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 
 
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils
 
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 620c5313..482afa1d 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -101,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
             self.callback_handler.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -274,7 +274,7 @@ class CustomDPOTrainer(DPOTrainer):
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps
 
         return loss
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 88f6d4cc..fd93974d 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -96,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
             self.add_callback(SaveProcessorCallback(processor))
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -247,7 +247,7 @@ class CustomKTOTrainer(KTOTrainer):
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps
 
         return loss
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index d1510c47..52e8ac51 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.add_callback(SaveProcessorCallback(processor))
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index eebdb179..333f8fa5 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -19,6 +19,7 @@ from transformers import Trainer
 from typing_extensions import override
 
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -51,7 +52,7 @@ class CustomTrainer(Trainer):
             self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -68,3 +69,15 @@ class CustomTrainer(Trainer):
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
+
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+
+        return loss
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 311b9005..2cb6ebb3 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -60,7 +60,7 @@ class PairwiseTrainer(Trainer):
             self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -100,7 +100,7 @@ class PairwiseTrainer(Trainer):
 
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
 
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
 
         if return_outputs:
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 609f0f06..573c716e 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -27,6 +27,7 @@ from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -78,6 +79,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+
+        return loss
+
     @override
     def prediction_step(
         self,
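Note (editor's illustration, not part of the patch): every trainer touched above applies the same workaround. On transformers 4.46.0/4.46.1 the loss returned by the stock Trainer.compute_loss is divided by gradient_accumulation_steps: in the pt/sft trainers whenever the model does not accept the new loss kwargs (the case the commit subject targets), and in the dpo/kto/rm trainers whenever num_items_in_batch is passed in. A minimal sketch of the pt/sft variant, assuming only transformers and the is_transformers_version_equal_to_4_46 helper patched in extras/packages.py; the trainer class name is hypothetical:

from transformers import Trainer

from llamafactory.extras.packages import is_transformers_version_equal_to_4_46


class RescaledLossTrainer(Trainer):  # hypothetical name, for illustration only
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Delegate to the stock implementation first (kwargs may carry num_items_in_batch).
        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
        # On transformers 4.46.0/4.46.1, a model whose forward() does not accept the new
        # loss kwargs otherwise reports a loss that is not divided by the number of
        # gradient accumulation steps, so rescale it here (mirrors CustomTrainer and
        # CustomSeq2SeqTrainer above).
        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
            loss = loss / self.args.gradient_accumulation_steps
        return loss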