refactor model_dtype, fix PPO trainer

Former-commit-id: 2818af0b09
This commit is contained in:
hiyouga
2023-10-11 23:16:01 +08:00
parent c350ba0f05
commit c9d1cd108d
10 changed files with 104 additions and 119 deletions

View File

@@ -145,6 +145,9 @@ class Runner:
)
args[compute_type] = True
if args["quantization_bit"] is not None:
args["upcast_layernorm"] = True
if args["stage"] == "ppo":
args["reward_model"] = reward_model
val_size = 0