update tests

Former-commit-id: 93d3b8f43f
This commit is contained in:
hiyouga
2024-11-02 12:21:41 +08:00
parent 25093c2d82
commit 3f7c874594
23 changed files with 53 additions and 62 deletions

View File

@@ -19,7 +19,7 @@
# limitations under the License.
import inspect
from functools import partial, wraps
from functools import WRAPPER_ASSIGNMENTS, partial, wraps
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
Only applies gradient checkpointing to trainable layers.
"""
@wraps(gradient_checkpointing_func)
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: "torch.nn.Module" = func.__self__
@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
return gradient_checkpointing_func(func, *args, **kwargs)
if hasattr(gradient_checkpointing_func, "__self__"): # fix unsloth gc test case
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
return custom_gradient_checkpointing_func