mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-03 20:22:49 +08:00)
parent 75829c8699
commit 84c3d509fa
@@ -81,7 +81,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
 
         # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
         if param.__class__.__name__ == "Params4bit":
-            num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1
+            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
+                num_bytes = param.quant_storage.itemsize
+            else:
+                num_bytes = 1
+
             num_params = num_params * 2 * num_bytes
 
         all_param += num_params
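For context, a minimal sketch of how this hunk might sit inside count_parameters after the change. Everything outside the Params4bit branch (the parameter loop and the trainable/total tallies) is inferred from the hunk header and context lines, not confirmed by this diff:

from typing import Tuple

import torch


def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    # Assumed surrounding loop: tally trainable and total parameter counts.
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()

        # Due to the design of 4bit linear layers from bitsandbytes,
        # multiply the number of parameters by 2
        if param.__class__.__name__ == "Params4bit":
            # Guard both attributes: quant_storage may be absent on older
            # bitsandbytes, and dtype.itemsize only exists on newer PyTorch,
            # so fall back to a 1-byte storage assumption instead of raising.
            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                num_bytes = param.quant_storage.itemsize
            else:
                num_bytes = 1

            num_params = num_params * 2 * num_bytes

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params

    return trainable_params, all_param

The correction itself exists because Params4bit packs two 4-bit weights per byte of its storage tensor, so param.numel() reports storage elements rather than logical weights; multiplying by 2 * num_bytes recovers the true count. The change in this commit only widens the hasattr guard: the old one-liner checked for quant_storage but then assumed it had an itemsize attribute, which fails where quant_storage is a dtype without that property.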