From 5226c4fa97c07511d677767341171e455dcd44c2 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Tue, 16 Apr 2024 17:29:52 +0800 Subject: [PATCH] Update trainer.py Former-commit-id: 6700a1b9fa0cbd965ac45d3f2de1088727235c25 --- src/llmtuner/train/sft/trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/llmtuner/train/sft/trainer.py b/src/llmtuner/train/sft/trainer.py index de741426..def427fd 100644 --- a/src/llmtuner/train/sft/trainer.py +++ b/src/llmtuner/train/sft/trainer.py @@ -1,5 +1,6 @@ import json import os +from types import MethodType from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -9,8 +10,7 @@ from transformers import Seq2SeqTrainer from ...extras.constants import IGNORE_INDEX from ...extras.logging import get_logger from ..utils import create_custom_optimzer, create_custom_scheduler -from types import MethodType -from packaging import version + if TYPE_CHECKING: from transformers.trainer import PredictionOutput @@ -31,6 +31,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): self.finetuning_args = finetuning_args if finetuning_args.use_badam: from badam import clip_grad_norm_for_sparse_tensor + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) def create_optimizer(self) -> "torch.optim.Optimizer":