support bf16 ppo #551

Former-commit-id: d125218cde893c7c8527ab27b4d2dfb2474c384d
This commit is contained in:
hiyouga 2023-08-18 00:40:32 +08:00
parent caf4a61e21
commit 66771352bb

View File

@ -151,21 +151,19 @@ def get_train_args(
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if training_args.fp16:
model_args.compute_dtype = torch.float16
elif training_args.bf16:
if training_args.bf16:
if not torch.cuda.is_bf16_supported():
raise ValueError("Current device does not support bf16 training.")
model_args.compute_dtype = torch.bfloat16
else:
model_args.compute_dtype = torch.float32
model_args.compute_dtype = torch.float16
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
# Log on each process the small summary:
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
training_args.local_rank, training_args.device, training_args.n_gpu,
bool(training_args.local_rank != -1), training_args.fp16
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
))
logger.info(f"Training/evaluation parameters {training_args}")