mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-04 12:42:51 +08:00
parent
f9c859e97b
commit
0170ef83a6
@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
from enum import Enum, unique
|
from enum import Enum, unique
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
@ -129,6 +130,10 @@ def gradient_checkpointing_enable(
|
|||||||
|
|
||||||
return gradient_checkpointing_func(func, *args, **kwargs)
|
return gradient_checkpointing_func(func, *args, **kwargs)
|
||||||
|
|
||||||
|
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||||
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||||
|
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||||
|
else:
|
||||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user