fix layer norm dtype

hiyouga
2023-09-28 00:25:55 +08:00
parent b0b0138e1d
commit 84b7486885
6 changed files with 28 additions and 22 deletions


@@ -67,6 +67,10 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
+    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
+        default="auto",
+        metadata={"help": "Data type of the layer norm weights."}
+    )
 
     def __post_init__(self):
         self.compute_dtype = None
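
The new `layernorm_dtype` option selects the precision of the layer-norm weights. Keeping them in fp32 during fp16/bf16 mixed-precision training is a common stability trick, since the variance computed inside the norm can overflow in half precision. Below is a minimal sketch of how such an option might be applied after the model is loaded; the `cast_layernorm_dtype` helper and the name-matching heuristic are assumptions for illustration, not the code added by this commit:

import torch
from torch import nn

DTYPE_MAP = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}

def cast_layernorm_dtype(model: nn.Module, layernorm_dtype: str = "auto") -> nn.Module:
    # Hypothetical helper, not the repository's actual implementation.
    if layernorm_dtype == "auto":
        return model  # keep whatever dtype the checkpoint was loaded with
    target = DTYPE_MAP[layernorm_dtype]
    for name, param in model.named_parameters():
        # LLaMA-style models typically name their RMSNorm/LayerNorm
        # weights with a "norm" substring; this match is a simplification.
        if "norm" in name:
            param.data = param.data.to(target)
    return model

A caller might then wire the argument through as `model = cast_layernorm_dtype(model, model_args.layernorm_dtype)` after loading the base model, leaving the default "auto" as a no-op.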