Mirror of https://github.com/hiyouga/LLaMA-Factory.git
[model] add llama4 (#7611)
@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
     def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
-        module: torch.nn.Module = func.__self__
+        if isinstance(func, partial):
+            module: torch.nn.Module = func.func.__self__
+        else:
+            module: torch.nn.Module = func.__self__
+
         has_grad = False
         if any(param.requires_grad for param in module.parameters()):
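Before this patch, custom_gradient_checkpointing_func assumed func was a bound method and read the owning module from func.__self__. The change adds a branch for the case where the checkpointed callable arrives wrapped in functools.partial (as appears to happen with the Llama 4 integration), since a partial hides __self__ one level down on its .func attribute. A minimal sketch of the attribute lookup involved; the Block module below is hypothetical, but the partial/__self__ behavior is standard Python:

from functools import partial

import torch


class Block(torch.nn.Module):
    # Hypothetical stand-in for a transformer layer.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


block = Block()

# A plain bound method exposes its owning module via __self__.
assert block.forward.__self__ is block

# Wrapping the bound method in functools.partial (e.g. to pin extra
# keyword arguments before checkpointing) hides __self__; the bound
# method now sits one level down on the partial's .func attribute.
wrapped = partial(block.forward)
assert not hasattr(wrapped, "__self__")
assert wrapped.func.__self__ is block

# Hence the branch added by this commit's wrapper:
func = wrapped
module = func.func.__self__ if isinstance(func, partial) else func.__self__
assert module is block

Resolving the module either way lets the wrapper keep its existing check that skips checkpointing for layers with no trainable parameters (the has_grad logic above).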