[model] add llama4 (#7611)

This commit is contained in:
hoshi-hiyouga
2025-04-06 13:42:31 +08:00
committed by GitHub
parent 6eb28bcacd
commit 40fb24916f
11 changed files with 167 additions and 8 deletions

View File

@@ -79,7 +79,10 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
@wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
module: torch.nn.Module = func.__self__
if isinstance(func, partial):
module: torch.nn.Module = func.func.__self__
else:
module: torch.nn.Module = func.__self__
has_grad = False
if any(param.requires_grad for param in module.parameters()):