fix ppo args

This commit is contained in:
hiyouga
2023-10-11 23:40:50 +08:00
parent 2818af0b09
commit 11bd271364
4 changed files with 18 additions and 9 deletions

View File

@@ -201,7 +201,9 @@ def get_train_args(
)
# postprocess model_args
model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
)
model_args.model_max_length = data_args.cutoff_len
# Log on each process the small summary: