Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-15 03:10:35 +08:00
@@ -19,7 +19,7 @@
 # limitations under the License.
 
 import inspect
-from functools import partial, wraps
+from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     Only applies gradient checkpointing to trainable layers.
     """
 
-    @wraps(gradient_checkpointing_func)
+    @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
@@ -92,9 +92,6 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 
         return gradient_checkpointing_func(func, *args, **kwargs)
 
-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
-        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
-
     return custom_gradient_checkpointing_func
 
 
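Context for the change: functools.wraps only copies the attributes listed in its assigned argument, and the default WRAPPER_ASSIGNMENTS tuple does not include __self__. Extending it to WRAPPER_ASSIGNMENTS + ("__self__",) lets the wrapper pick up the bound-method owner automatically, which is what makes the manual hasattr(...) / __self__ assignment removable. The sketch below is a minimal, self-contained illustration of that mechanism under assumed names; Checkpointer and make_wrapper are hypothetical and not part of LLaMA-Factory.

    from functools import WRAPPER_ASSIGNMENTS, wraps


    class Checkpointer:
        """Hypothetical stand-in for a library object exposing a bound checkpointing method."""

        def checkpoint(self, func, *args, **kwargs):
            # A real implementation would re-run `func` during the backward pass;
            # here we just call it so the example stays self-contained.
            return func(*args, **kwargs)


    def make_wrapper(gc_func):
        # WRAPPER_ASSIGNMENTS covers __name__, __doc__, __module__, etc.;
        # appending "__self__" also copies the bound-method owner, when present,
        # so callers that inspect wrapper.__self__ keep working.
        @wraps(gc_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
        def wrapper(func, *args, **kwargs):
            return gc_func(func, *args, **kwargs)

        return wrapper


    bound = Checkpointer().checkpoint          # bound method, so it has __self__
    wrapped = make_wrapper(bound)
    assert wrapped.__self__ is bound.__self__  # no manual hasattr()/assignment needed
    print(wrapped(lambda x: x + 1, 41))        # -> 42

    # update_wrapper() silently skips attributes the wrapped callable lacks, so the
    # same factory also works for plain (unbound) checkpointing functions.
    plain = make_wrapper(lambda func, *a, **kw: func(*a, **kw))
    assert not hasattr(plain, "__self__")

Because missing attributes are skipped rather than raising, a single decorator line handles both bound and unbound checkpointing callables, which is presumably why the explicit "fix unsloth gc test case" branch could be dropped.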