fix mod stuff

hiyouga
2024-04-21 18:11:10 +08:00
parent d0273787be
commit f58425ab45
16 changed files with 63 additions and 88 deletions


@@ -28,6 +28,8 @@ LOG_FILE_NAME = "trainer_log.jsonl"
METHODS = ["full", "freeze", "lora"]
MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]
PEFT_METHODS = ["lora"]
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

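Given the commit title, MOD_SUPPORTED_MODELS presumably whitelists the architectures that can be converted for mixture-of-depths (MoD) training. A minimal sketch of how such a whitelist might be checked follows; the check_mod_support helper is hypothetical and not part of this commit, and the list is copied from the hunk above.

    # Whitelist added in this commit (copied from the hunk above).
    MOD_SUPPORTED_MODELS = ["bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"]

    def check_mod_support(model_type: str) -> None:
        # Hypothetical helper: refuse MoD conversion for unsupported architectures.
        if model_type not in MOD_SUPPORTED_MODELS:
            raise ValueError(
                "Model type `{}` is not supported for mixture-of-depths training.".format(model_type)
            )

    check_mod_support("llama")    # passes
    # check_mod_support("qwen2")  # would raise ValueError

The model_type string would typically come from the Hugging Face config (config.model_type) of the model being loaded.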

@@ -83,6 +83,8 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
if param.__class__.__name__ == "Params4bit":
    if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
        num_bytes = param.quant_storage.itemsize
    elif hasattr(param, "element_size"):  # for older PyTorch versions
        num_bytes = param.element_size()
    else:
        num_bytes = 1
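For context on why num_bytes matters: bitsandbytes stores 4-bit weights packed into a wider storage dtype, so param.numel() on a Params4bit tensor reports storage elements rather than logical weights. Newer bitsandbytes exposes the storage dtype via quant_storage (a torch.dtype with an itemsize), while older setups fall back to element_size(). Below is a hedged reconstruction of how the byte width typically feeds into the count; the surrounding loop and the count_parameters_sketch name are assumptions, not part of this diff.

    import torch
    from typing import Tuple

    def count_parameters_sketch(model: torch.nn.Module) -> Tuple[int, int]:
        # Simplified reconstruction of a parameter counter around the hunk above;
        # not the exact upstream function body.
        trainable_params, all_params = 0, 0
        for param in model.parameters():
            num_params = param.numel()
            if param.__class__.__name__ == "Params4bit":  # bitsandbytes 4-bit weight
                if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                    num_bytes = param.quant_storage.itemsize
                elif hasattr(param, "element_size"):  # older PyTorch fallback
                    num_bytes = param.element_size()
                else:
                    num_bytes = 1
                # 4-bit quantization packs two logical weights into each storage byte,
                # so numel() undercounts: scale by 2 * bytes-per-storage-element.
                num_params = num_params * 2 * num_bytes
            all_params += num_params
            if param.requires_grad:
                trainable_params += num_params
        return trainable_params, all_params

For example, with quant_storage = torch.uint8 (itemsize 1), numel() returns half the logical weight count and the 2 * 1 factor restores it; with a float16 storage dtype (itemsize 2), numel() returns a quarter of it and 2 * 2 restores it.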