Merge pull request #6547 from hiyouga/hiyouga/fix_pixtral_dpo

[trainer] fix pixtral dpo

Former-commit-id: c973f32849b979a3ebb80caa01029b43fbb620ac
This commit is contained in:
hoshi-hiyouga 2025-01-07 14:38:55 +08:00 committed by GitHub
commit a0bcac80c0
3 changed files with 32 additions and 10 deletions

View File

@ -31,7 +31,7 @@ from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
@ -193,7 +193,7 @@ class CustomDPOTrainer(DPOTrainer):
Otherwise the average log probabilities.
"""
if self.finetuning_args.use_ref_model:
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
batch = nested_detach(batch, clone=True) # avoid error
all_logits: "torch.Tensor" = model(**batch, return_dict=True, use_cache=False).logits.to(torch.float32)
all_logps, valid_length = get_batch_logps(logits=all_logits, labels=batch["labels"])

View File

@ -30,7 +30,7 @@ from typing_extensions import override
from ...extras.constants import IGNORE_INDEX
from ...extras.packages import is_transformers_version_equal_to_4_46, is_transformers_version_greater_than
from ..callbacks import SaveProcessorCallback
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler, get_batch_logps, nested_detach
if TYPE_CHECKING:
@ -142,7 +142,7 @@ class CustomKTOTrainer(KTOTrainer):
r"""
Runs forward pass and computes the log probabilities.
"""
batch = {k: v.detach().clone() for k, v in batch.items()} # avoid error
batch = nested_detach(batch, clone=True) # avoid error
model_inputs = {
"input_ids": batch[f"{prefix}input_ids"],
"attention_mask": batch[f"{prefix}attention_mask"],

View File

@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
import torch
@ -36,7 +37,7 @@ from ..model import find_all_linear_modules, load_model, load_tokenizer, load_va
if is_galore_available():
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit
from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit # type: ignore
if TYPE_CHECKING:
@ -330,7 +331,7 @@ def _create_badam_optimizer(
]
if finetuning_args.badam_mode == "layer":
from badam import BlockOptimizer
from badam import BlockOptimizer # type: ignore
base_optimizer = optim_class(param_groups, **optim_kwargs)
optimizer = BlockOptimizer(
@ -350,7 +351,7 @@ def _create_badam_optimizer(
)
elif finetuning_args.badam_mode == "ratio":
from badam import BlockOptimizerRatio
from badam import BlockOptimizerRatio # type: ignore
assert finetuning_args.badam_update_ratio > 1e-6
optimizer = BlockOptimizerRatio(
@ -374,7 +375,7 @@ def _create_adam_mini_optimizer(
model: "PreTrainedModel",
training_args: "Seq2SeqTrainingArguments",
) -> "torch.optim.Optimizer":
from adam_mini import Adam_mini
from adam_mini import Adam_mini # type: ignore
hidden_size = getattr(model.config, "hidden_size", None)
num_q_head = getattr(model.config, "num_attention_heads", None)
@ -459,12 +460,33 @@ def get_batch_logps(
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
def nested_detach(
tensors: Union["torch.Tensor", List["torch.Tensor"], Tuple["torch.Tensor"], Dict[str, "torch.Tensor"]],
clone: bool = False,
):
r"""
Detach `tensors` (even if it's a nested list/tuple/dict of tensors).
"""
if isinstance(tensors, (list, tuple)):
return type(tensors)(nested_detach(t, clone=clone) for t in tensors)
elif isinstance(tensors, Mapping):
return type(tensors)({k: nested_detach(t, clone=clone) for k, t in tensors.items()})
if isinstance(tensors, torch.Tensor):
if clone:
return tensors.detach().clone()
else:
return tensors.detach()
else:
return tensors
def get_swanlab_callback(finetuning_args: "FinetuningArguments") -> "TrainerCallback":
r"""
Gets the callback for logging to SwanLab.
"""
import swanlab
from swanlab.integration.transformers import SwanLabCallback
import swanlab # type: ignore
from swanlab.integration.transformers import SwanLabCallback # type: ignore
if finetuning_args.swanlab_api_key is not None:
swanlab.login(api_key=finetuning_args.swanlab_api_key)