This commit is contained in:
hiyouga
2023-10-15 18:28:45 +08:00
parent 0d63584c03
commit a6a04be2e6
9 changed files with 40 additions and 57 deletions

View File

@@ -186,7 +186,7 @@ def get_train_args(
# postprocess model_args
model_args.compute_dtype = (
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else torch.float32)
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
)
model_args.model_max_length = data_args.cutoff_len