support export model on Ascend NPU

This commit is contained in:
statelesshz
2023-09-20 10:15:59 +08:00
parent 10ab2f8b90
commit b3e41c6d49

View File

@@ -13,7 +13,7 @@ from transformers import (
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase PreTrainedTokenizerBase
) )
from transformers.utils import check_min_version from transformers.utils import check_min_version, is_torch_npu_available
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from trl import AutoModelForCausalLMWithValueHead from trl import AutoModelForCausalLMWithValueHead
@@ -215,6 +215,9 @@ def load_model_and_tokenizer(
# Prepare model for inference # Prepare model for inference
if not is_trainable: if not is_trainable:
model.requires_grad_(False) # fix all model params model.requires_grad_(False) # fix all model params
if is_torch_npu_available():
infer_dtype = torch.float16
else:
infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
model = model.to(infer_dtype) if model_args.quantization_bit is None else model model = model.to(infer_dtype) if model_args.quantization_bit is None else model