Former-commit-id: 9ae646fbbd809057a9c54fe41e1ae5a07a674556
This commit is contained in:
hiyouga 2024-03-24 00:43:21 +08:00
parent c548ad5e69
commit 92248f9cb2

View File

@ -81,7 +81,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
if param.__class__.__name__ == "Params4bit": if param.__class__.__name__ == "Params4bit":
num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1 if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
num_bytes = param.quant_storage.itemsize
else:
num_bytes = 1
num_params = num_params * 2 * num_bytes num_params = num_params * 2 * num_bytes
all_param += num_params all_param += num_params