From 496ba46960b062b070f80a652b76257c4f01a8a9 Mon Sep 17 00:00:00 2001
From: statelesshz
Date: Wed, 20 Sep 2023 10:15:59 +0800
Subject: [PATCH] support export model on Ascend NPU

Former-commit-id: 50f94e6d9d62c848db7a3db85fa999d67ddd9f04
---
 src/llmtuner/tuner/core/loader.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index 2570c5d7..6cc40a33 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -13,7 +13,7 @@ from transformers import (
     PreTrainedModel,
     PreTrainedTokenizerBase
 )
-from transformers.utils import check_min_version
+from transformers.utils import check_min_version, is_torch_npu_available
 from transformers.utils.versions import require_version
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -215,7 +215,10 @@ def load_model_and_tokenizer(
     # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+        if is_torch_npu_available():
+            infer_dtype = torch.float16
+        else:
+            infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
         model = model.to(infer_dtype) if model_args.quantization_bit is None else model
 
     trainable_params, all_param = count_parameters(model)