Former-commit-id: 140ad4ad567de8817a14972175e668971bae6a0a
This commit is contained in:
hiyouga 2024-03-24 00:43:21 +08:00
parent 75829c8699
commit 84c3d509fa

View File

@ -81,7 +81,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit":
num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
else:
num_bytes = 1
num_params = num_params * 2 * num_bytes
all_param += num_params