refactor model_dtype, fix PPO trainer

Former-commit-id: 2818af0b09
This commit is contained in:
hiyouga
2023-10-11 23:16:01 +08:00
parent c350ba0f05
commit c9d1cd108d
10 changed files with 104 additions and 119 deletions

View File

@@ -67,9 +67,9 @@ class ModelArguments:
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
default="auto",
metadata={"help": "Data type of the layer norm weights."}
upcast_layernorm: Optional[bool] = field(
default=False,
    metadata={"help": "Whether to upcast the layernorm weights to fp32."}
)
def __post_init__(self):