mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
fix test case
This commit is contained in:
@@ -89,6 +89,9 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
|
||||
|
||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||
|
||||
if hasattr(gradient_checkpointing_func, "__self__"): # fix test case
|
||||
custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
|
||||
|
||||
return custom_gradient_checkpointing_func
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user