From 4de9ef568a7827a574e748b4de53cf4bf7df8e93 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 5 Jul 2023 15:13:00 +0800 Subject: [PATCH] fix compute dtype Former-commit-id: 5aadbb22730d19570b039462c91df443dbb34b5f --- src/utils/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/common.py b/src/utils/common.py index 1086be35..917bd867 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -213,7 +213,7 @@ def load_pretrained( model = AutoModelForCausalLM.from_pretrained( model_to_load, config=config, - torch_dtype=model_args.compute_dtype, + torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16, low_cpu_mem_usage=True, **config_kwargs )