mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-15 16:18:10 +08:00
parent
c548ad5e69
commit
92248f9cb2
@ -81,7 +81,11 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
|||||||
|
|
||||||
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
|
||||||
if param.__class__.__name__ == "Params4bit":
|
if param.__class__.__name__ == "Params4bit":
|
||||||
num_bytes = param.quant_storage.itemsize if hasattr(param, "quant_storage") else 1
|
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
||||||
|
num_bytes = param.quant_storage.itemsize
|
||||||
|
else:
|
||||||
|
num_bytes = 1
|
||||||
|
|
||||||
num_params = num_params * 2 * num_bytes
|
num_params = num_params * 2 * num_bytes
|
||||||
|
|
||||||
all_param += num_params
|
all_param += num_params
|
||||||
|
Loading…
x
Reference in New Issue
Block a user