Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-16 11:50:35 +08:00)
@@ -151,13 +151,16 @@ def get_train_args(
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

     if model_args.quantization_bit is not None:
-        if training_args.fp16:
-            model_args.compute_dtype = torch.float16
-        elif training_args.bf16:
-            model_args.compute_dtype = torch.bfloat16
-        else:
-            model_args.compute_dtype = torch.float32
+        if training_args.fp16:
+            model_args.compute_dtype = torch.float16
+        elif training_args.bf16:
+            if not torch.cuda.is_bf16_supported():
+                raise ValueError("Current device does not support bf16 training.")
+            model_args.compute_dtype = torch.bfloat16
+        else:
+            model_args.compute_dtype = torch.float32

     model_args.model_max_length = data_args.max_source_length + data_args.max_target_length

     # Log on each process the small summary:
     logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
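Editor's note: the only behavioral change in this hunk is the new torch.cuda.is_bf16_supported() guard, which makes bf16 training fail fast on hardware without bfloat16 support instead of erroring later inside the training loop. The sketch below is a minimal, self-contained reproduction of that dtype-resolution order, not LLaMA-Factory's actual parser; the Flags dataclass is a hypothetical stand-in for the fp16/bf16 fields of the HF TrainingArguments object used above.

    # Minimal sketch, assuming a standalone stub in place of TrainingArguments.
    # Resolution order matches the hunk: fp16 -> bf16 (with hardware check) -> fp32.
    from dataclasses import dataclass

    import torch


    @dataclass
    class Flags:  # hypothetical stand-in for the fp16/bf16 training flags
        fp16: bool = False
        bf16: bool = False


    def resolve_compute_dtype(flags: Flags) -> torch.dtype:
        if flags.fp16:
            return torch.float16
        if flags.bf16:
            # Same guard the commit adds: bfloat16 generally needs newer GPUs
            # (e.g. NVIDIA Ampere or later), so fail early if unsupported.
            if not torch.cuda.is_bf16_supported():
                raise ValueError("Current device does not support bf16 training.")
            return torch.bfloat16
        return torch.float32


    print(resolve_compute_dtype(Flags(fp16=True)))  # -> torch.float16

Note also the unchanged context line: model_args.model_max_length is simply the sum of max_source_length and max_target_length, so the prompt and response budgets together bound the sequence length the model is configured for.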