mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 03:40:34 +08:00
@@ -89,6 +89,9 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
|
||||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
if hasattr(gradient_checkpointing_func, "__self__"): # fix test case
|
||||
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
|
||||
|
||||
return custom_gradient_checkpointing_func
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user