From 66771352bbe4e8d2321a5938808bd5b3abd8dab1 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Fri, 18 Aug 2023 00:40:32 +0800
Subject: [PATCH] support bf16 ppo #551

Former-commit-id: d125218cde893c7c8527ab27b4d2dfb2474c384d
---
 src/llmtuner/tuner/core/parser.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/tuner/core/parser.py
index 7dda0a5c..e039513d 100644
--- a/src/llmtuner/tuner/core/parser.py
+++ b/src/llmtuner/tuner/core/parser.py
@@ -151,21 +151,19 @@ def get_train_args(
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
-    if training_args.fp16:
-        model_args.compute_dtype = torch.float16
-    elif training_args.bf16:
+    if training_args.bf16:
         if not torch.cuda.is_bf16_supported():
             raise ValueError("Current device does not support bf16 training.")
         model_args.compute_dtype = torch.bfloat16
     else:
-        model_args.compute_dtype = torch.float32
+        model_args.compute_dtype = torch.float16
 
     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
     # Log on each process the small summary:
-    logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
+    logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
         training_args.local_rank, training_args.device, training_args.n_gpu,
-        bool(training_args.local_rank != -1), training_args.fp16
+        bool(training_args.local_rank != -1), str(model_args.compute_dtype)
     ))
     logger.info(f"Training/evaluation parameters {training_args}")
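
For reference, below is a minimal, self-contained sketch of the dtype-resolution behavior this patch leaves in place. The stub dataclasses TrainingArgsStub and ModelArgsStub and the helper name resolve_compute_dtype are illustrative stand-ins, not the real Seq2SeqTrainingArguments / ModelArguments used by the repository; only the branching logic mirrors the patched parser.py.

# Sketch only: stub classes stand in for the real HF training arguments and
# llmtuner's ModelArguments; resolve_compute_dtype is a hypothetical helper name.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TrainingArgsStub:
    fp16: bool = False
    bf16: bool = False


@dataclass
class ModelArgsStub:
    compute_dtype: Optional[torch.dtype] = None


def resolve_compute_dtype(training_args: TrainingArgsStub, model_args: ModelArgsStub) -> None:
    # bf16 is only honored when the device supports it; every other case
    # (including the PPO path) now falls back to float16 instead of float32.
    if training_args.bf16:
        if not torch.cuda.is_bf16_supported():
            raise ValueError("Current device does not support bf16 training.")
        model_args.compute_dtype = torch.bfloat16
    else:
        model_args.compute_dtype = torch.float16


# Usage: without --bf16, the compute dtype resolves to float16.
args, margs = TrainingArgsStub(bf16=False), ModelArgsStub()
resolve_compute_dtype(args, margs)
assert margs.compute_dtype is torch.float16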