From 8c57169eb7379ea6a4729bf898e2e9678209cd04 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 7 Jan 2025 06:30:44 +0000
Subject: [PATCH] fix #6546

Former-commit-id: 870f23d7eaff1e32a73fee4eb972163c85ba7b67
---
 src/llamafactory/train/dpo/trainer.py   |  4 +--
 src/llamafactory/train/kto/trainer.py   |  4 +--
 src/llamafactory/train/trainer_utils.py | 34 ++++++++++++++++++++-----
 3 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index ad670385..770b32e5 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -31,7 +31,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
 
 
 if TYPE_CHECKING:
@@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
         Otherwise the average log probabilities.
         """
         if self.finetuning_args.use_ref_model:
-            batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+            batch = nested_detach(batch, clone=True)  # avoid error
 
         all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
         all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 3c6d1089..419de579 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -30,7 +30,7 @@ from typing_extensions import override
 from ...extras.constants import IGNORE_INDEX
 from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
 from ..callbacks import SaveProcessorCallback
-from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
 
 
 if TYPE_CHECKING:
@@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
         r"""
         Runs forward pass and computes the log probabilities.
         """
-        batch = {k: v.detach().clone() for k, v in batch.items()}  # avoid error
+        batch = nested_detach(batch, clone=True)  # avoid error
         model_inputs = {
             "input_ids": batch[f"{prefix}input_ids"],
             "attention_mask": batch[f"{prefix}attention_mask"],
diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py
index eb2421ce..4cd2337e 100644
--- a/src/llamafactory/train/trainer_utils.py
+++ b/src/llamafactory/train/trainer_utils.py
@@ -17,6 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -36,7 +37,7 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
 
 
 if is_galore_available():
-    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
+    from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit  # type: ignore
 
 
 if TYPE_CHECKING:
@@ -330,7 +331,7 @@ def _create_badam_optimizer(
     ]
 
     if finetuning_args.badam_mode == "layer":
-        from badam import BlockOptimizer
+        from badam import BlockOptimizer  # type: ignore
 
         base_optimizer = optim_class(param_groups, **optim_kwargs)
         optimizer = BlockOptimizer(
@@ -350,7 +351,7 @@
         )
 
     elif finetuning_args.badam_mode == "ratio":
-        from badam import BlockOptimizerRatio
+        from badam import BlockOptimizerRatio  # type: ignore
 
         assert finetuning_args.badam_update_ratio > 1e-6
         optimizer = BlockOptimizerRatio(
@@ -374,7 +375,7 @@ def _create_adam_mini_optimizer(
     model: "PreTrainedModel",
     training_args: "Seq2SeqTrainingArguments",
 ) -> "torch.optim.Optimizer":
-    from adam_mini import Adam_mini
+    from adam_mini import Adam_mini  # type: ignore
 
     hidden_size = getattr(model.config, "hidden_size", None)
     num_q_head = getattr(model.config, "num_attention_heads", None)
@@ -459,12 +460,33 @@ def get_batch_logps(
     return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
 
 
+def nested_detach(
+    tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
+    clone: bool = False,
+):
+    r"""
+    Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
+    """
+    if isinstance(tensors, (list, tuple)):
+        return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
+    elif isinstance(tensors, Mapping):
+        return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
+
+    if isinstance(tensors, torch.Tensor):
+        if clone:
+            return tensors.detach().clone()
+        else:
+            return tensors.detach()
+    else:
+        return tensors
+
+
 def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
     r"""
     Gets the callback for logging to SwanLab.
     """
-    import swanlab
-    from swanlab.integration.transformers import SwanLabCallback
+    import swanlab  # type: ignore
+    from swanlab.integration.transformers import SwanLabCallback  # type: ignore
 
     if finetuning_args.swanlab_api_key is not None:
         swanlab.login(api_key=finetuning_args.swanlab_api_key)
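
Illustrative usage sketch (not part of the patch above): assuming the new helper is importable as llamafactory.train.trainer_utils.nested_detach, the snippet below shows the intended behavior on a nested batch dict: tensors are detached (and copied when clone=True) while non-tensor values pass through unchanged. The batch contents are made-up placeholders.

import torch

from llamafactory.train.trainer_utils import nested_detach  # assumed import path

# Made-up batch: tensors plus one non-tensor entry, mirroring the trainer's input dicts.
batch = {
    "input_ids": torch.randint(0, 100, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
    "logps": torch.randn(2, requires_grad=True),
    "prefix": "kl_",  # non-tensor values are returned untouched
}

detached = nested_detach(batch, clone=True)

assert not detached["logps"].requires_grad  # gradient tracking is cut
assert detached["input_ids"].data_ptr() != batch["input_ids"].data_ptr()  # clone=True copies storage
assert detached["prefix"] == "kl_"  # non-tensors pass through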