fix save function

This commit is contained in:
hiyouga
2023-07-21 14:09:07 +08:00
parent 1d1d8538c9
commit d2f18197e3
2 changed files with 4 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
import os
import torch
from typing import Dict
from typing import Dict, Optional
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
from transformers.modeling_utils import load_sharded_checkpoint
@@ -12,12 +12,12 @@ from llmtuner.extras.logging import get_logger
# Module-level logger, obtained from the project's get_logger helper using this module's name.
logger = get_logger(__name__)
def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
    """Return CPU copies of the model's parameters, keyed by parameter name.

    Args:
        model: Module whose parameters are collected.
        trainable_only: When truthy (the default), only parameters with
            ``requires_grad=True`` are kept. When falsy, every entry yielded
            by ``model.named_parameters()`` is kept.

    Returns:
        A new dict mapping each selected parameter name to a detached CPU
        clone of the matching ``model.state_dict()`` tensor. Only names from
        ``named_parameters()`` are iterated, so buffers are not included.
    """
    state_dict = model.state_dict()
    # Clone so the returned tensors do not share storage with the live model.
    return {
        name: state_dict[name].cpu().clone().detach()
        for name, param in model.named_parameters()
        if (not trainable_only) or param.requires_grad
    }