refactor model_dtype, fix PPO trainer

Author: hiyouga
Date:   2023-10-11 23:16:01 +08:00
Parent: 5310e4d182
Commit: 2818af0b09

10 changed files with 104 additions and 119 deletions


@@ -67,9 +67,9 @@ class ModelArguments:
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."}
     )
-    layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
-        default="auto",
-        metadata={"help": "Data type of the layer norm weights."}
+    upcast_layernorm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to upcast the layernorm weights in fp32."}
     )
 
     def __post_init__(self):
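
For context on the change above: a boolean flag like upcast_layernorm is typically consumed once, right after the model is loaded, by casting every layernorm weight to fp32 for numerical stability under fp16/bf16 mixed precision. The sketch below is a minimal illustration of that pattern, not code from this commit; the LAYERNORM_NAMES list and the upcast_layernorm_weights helper are assumed names chosen for the example.

import torch
from torch import nn

# Hypothetical name substrings that identify layernorm parameters;
# adjust for the model family being loaded.
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]

def upcast_layernorm_weights(model: nn.Module) -> None:
    """Cast 1-D layernorm weights to fp32 (illustrative sketch)."""
    for name, param in model.named_parameters():
        # Layernorm weights and biases are 1-D tensors whose names
        # contain one of the norm keywords above.
        if param.ndim == 1 and any(ln in name for ln in LAYERNORM_NAMES):
            param.data = param.data.to(torch.float32)

In use, a trainer would call this once after from_pretrained, guarded by the new flag: if model_args.upcast_layernorm: upcast_layernorm_weights(model).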