Former-commit-id: 5cc7a447843c578af602a5e054348fad1c9306ce
hiyouga 2023-09-22 15:00:48 +08:00
parent db21953bf0
commit e930682152
3 changed files with 18 additions and 12 deletions

View File

@@ -369,8 +369,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API Demo

View File

@@ -368,8 +368,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API Service
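Both README examples drop the trailing `--fp16` flag: with the `get_infer_args` change below, the export dtype is picked from the hardware instead of being forced. As a minimal, hypothetical usage sketch (assuming the export records its dtype in the checkpoint's `config.json`; `path_to_export` is the placeholder from the examples above, and `torch_dtype="auto"` is stock transformers behavior, not something this patch adds), the exported model can be reloaded at its saved precision:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative only: "path_to_export" is the placeholder from the README example.
# torch_dtype="auto" tells transformers to honor the dtype recorded in the
# exported config.json rather than defaulting to float32.
tokenizer = AutoTokenizer.from_pretrained("path_to_export")
model = AutoModelForCausalLM.from_pretrained("path_to_export", torch_dtype="auto")
```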

View File

@@ -9,10 +9,12 @@ from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
 try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
+    is_fp16_available = is_torch_cuda_available()
     is_bf16_available = is_torch_bf16_gpu_available()
     is_npu_available = is_torch_npu_available()
 except ImportError:
+    is_fp16_available = torch.cuda.is_available()
     is_bf16_available = torch.cuda.is_bf16_supported()
     is_npu_available = False
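The try/except keeps the module importable on older transformers releases that lack these capability helpers, falling back to raw torch probes. A standalone sketch of the same pattern (names mirror the patch; the bf16 fallback is guarded with `is_available()` here so the sketch also runs on CPU-only hosts, whereas the patch calls `is_bf16_supported()` directly):

```python
import torch

try:
    # Newer transformers exposes capability predicates directly.
    from transformers.utils import (
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available,
    )
    is_fp16_available = is_torch_cuda_available()
    is_bf16_available = is_torch_bf16_gpu_available()
    is_npu_available = is_torch_npu_available()
except ImportError:
    # Older transformers: probe torch directly. NPU detection is unavailable here.
    is_fp16_available = torch.cuda.is_available()
    is_bf16_available = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    is_npu_available = False

print(f"fp16={is_fp16_available} bf16={is_bf16_available} npu={is_npu_available}")
```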
@@ -29,6 +31,17 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
+def _infer_dtype() -> torch.dtype:
+    if is_npu_available:
+        return torch.float16
+    elif is_bf16_available:
+        return torch.bfloat16
+    elif is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
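The new helper encodes a fixed precedence: NPU first, then bf16-capable GPU, then any CUDA GPU, then CPU. A hypothetical, parameterized mirror (not part of the patch) that makes the precedence testable on any machine:

```python
import torch

def infer_dtype(npu: bool, bf16: bool, fp16: bool) -> torch.dtype:
    """Hypothetical, parameterized mirror of _infer_dtype() for illustration."""
    if npu:
        return torch.float16    # Ascend NPU: float16
    elif bf16:
        return torch.bfloat16   # bf16-capable GPU (e.g. Ampere): bfloat16
    elif fp16:
        return torch.float16    # any other CUDA GPU: float16
    else:
        return torch.float32    # CPU-only: full precision

assert infer_dtype(True, True, True) is torch.float16    # NPU outranks bf16
assert infer_dtype(False, True, True) is torch.bfloat16
assert infer_dtype(False, False, True) is torch.float16
assert infer_dtype(False, False, False) is torch.float32
```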
@@ -211,7 +224,7 @@ def get_train_args(
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
     else:
-        model_args.compute_dtype = torch.float32
+        model_args.compute_dtype = _infer_dtype()
 
     model_args.model_max_length = data_args.cutoff_len
 
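In `get_train_args` the inferred dtype is only a fallback: explicit `--bf16`/`--fp16` flags still win. A compact, hypothetical condensation of the resolution order (helper name and signature are illustrative, not from the patch):

```python
import torch

def resolve_train_dtype(bf16_flag: bool, fp16_flag: bool, inferred: torch.dtype) -> torch.dtype:
    """Hypothetical condensation of the compute_dtype logic in get_train_args."""
    if bf16_flag:
        return torch.bfloat16   # explicit --bf16
    if fp16_flag:
        return torch.float16    # explicit --fp16
    return inferred             # patched fallback: hardware inference, not float32

# Before this patch the fallback was always torch.float32.
assert resolve_train_dtype(False, False, torch.bfloat16) is torch.bfloat16
```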
@@ -252,11 +265,6 @@ def get_infer_args(
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
     # auto-detect cuda capability
-    if is_npu_available:
-        model_args.compute_dtype = torch.float16
-    elif is_bf16_available:
-        model_args.compute_dtype = torch.bfloat16
-    else:
-        model_args.compute_dtype = torch.float16
+    model_args.compute_dtype = _infer_dtype()
 
     return model_args, data_args, finetuning_args, generating_args
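Replacing the inline branch with `_infer_dtype()` removes duplicated logic, with one subtle behavioral change: the old branch fell through to float16 even without an accelerator, while `_infer_dtype()` returns float32 on CPU-only hosts. A small comparison, with both rules reconstructed from the diff as parameterized functions:

```python
import torch

def old_rule(npu: bool, bf16: bool) -> torch.dtype:
    """The branch removed from get_infer_args, reconstructed from the diff."""
    if npu:
        return torch.float16
    elif bf16:
        return torch.bfloat16
    return torch.float16            # float16 even on CPU-only hosts

def new_rule(npu: bool, bf16: bool, fp16: bool) -> torch.dtype:
    """Parameterized mirror of the shared _infer_dtype() helper."""
    if npu:
        return torch.float16
    elif bf16:
        return torch.bfloat16
    elif fp16:
        return torch.float16
    return torch.float32            # full precision when no accelerator is found

# Identical whenever an NPU or CUDA device is present ...
assert old_rule(False, True) is new_rule(False, True, True)
# ... but CPU-only inference now runs in float32 instead of float16.
assert old_rule(False, False) is torch.float16
assert new_rule(False, False, False) is torch.float32
```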