mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 22:32:54 +08:00
parent db21953bf0
commit e930682152
````diff
@@ -369,8 +369,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API Demo
````
````diff
@@ -368,8 +368,7 @@ python src/export_model.py \
     --template default \
     --finetuning_type lora \
     --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export \
-    --fp16
+    --output_dir path_to_export
 ```
 
 ### API 服务
````
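Both README hunks (the English one above the "### API Demo" heading and the Chinese one above "### API 服务") make the same edit: the trailing `--fp16` flag is dropped from the `export_model.py` example, because with this commit the compute dtype is auto-detected by the new `_infer_dtype()` helper shown below rather than forced on the command line. As a hedged illustration only, a programmatic call could look like the sketch below; it assumes `llmtuner` exposes an `export_model()` entry point that accepts the CLI arguments as a dict, and the path strings are placeholders.

```python
# Hedged sketch: exporting a LoRA checkpoint without an explicit --fp16 flag.
# Assumption: llmtuner exposes export_model() and it accepts the CLI arguments
# as a dict; the compute dtype is now resolved automatically by _infer_dtype().
from llmtuner import export_model

export_model(dict(
    model_name_or_path="path_to_llama_model",  # placeholder base model path (not part of the hunk)
    template="default",
    finetuning_type="lora",
    checkpoint_dir="path_to_checkpoint",
    output_dir="path_to_export",
))
```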
```diff
@@ -9,10 +9,12 @@ from transformers.utils.versions import require_version
 from transformers.trainer_utils import get_last_checkpoint
 
 try:
-    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
+    from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
+    is_fp16_available = is_torch_cuda_available()
     is_bf16_available = is_torch_bf16_gpu_available()
     is_npu_available = is_torch_npu_available()
 except ImportError:
+    is_fp16_available = torch.cuda.is_available()
     is_bf16_available = torch.cuda.is_bf16_supported()
     is_npu_available = False
 
```
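The import hunk adds a third capability probe: alongside the existing bf16 and NPU checks, `is_torch_cuda_available` now sets `is_fp16_available`, with a fallback to plain `torch.cuda` calls when an older transformers build does not export these helpers. A minimal self-contained sketch of the same guarded probing:

```python
# Standalone sketch of the guarded capability probing used above: prefer the
# transformers helpers, fall back to raw torch calls on an ImportError
# (i.e. a transformers version that predates these utilities).
import torch

try:
    from transformers.utils import (
        is_torch_bf16_gpu_available,
        is_torch_cuda_available,
        is_torch_npu_available,
    )
    is_fp16_available = is_torch_cuda_available()
    is_bf16_available = is_torch_bf16_gpu_available()
    is_npu_available = is_torch_npu_available()
except ImportError:
    is_fp16_available = torch.cuda.is_available()
    is_bf16_available = torch.cuda.is_bf16_supported()
    is_npu_available = False

print(f"fp16={is_fp16_available} bf16={is_bf16_available} npu={is_npu_available}")
```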
```diff
@@ -29,6 +31,17 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
 
+def _infer_dtype() -> torch.dtype:
+    if is_npu_available:
+        return torch.float16
+    elif is_bf16_available:
+        return torch.bfloat16
+    elif is_fp16_available:
+        return torch.float16
+    else:
+        return torch.float32
+
+
 def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
```
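The new `_infer_dtype()` helper encodes a fixed precedence: an Ascend NPU forces float16, a bf16-capable GPU gets bfloat16, any other CUDA device gets float16, and a machine with no accelerator falls back to float32. A small sketch with the availability flags passed in explicitly (local names only, not repository code), so the cascade can be checked on any machine:

```python
# Sketch of the same precedence with the flags stubbed as parameters, so the
# cascade can be exercised without particular hardware (illustrative only).
import torch

def infer_dtype(npu: bool, bf16: bool, fp16: bool) -> torch.dtype:
    if npu:          # Ascend NPU: stay on float16
        return torch.float16
    elif bf16:       # GPU with bfloat16 support
        return torch.bfloat16
    elif fp16:       # any other CUDA device
        return torch.float16
    else:            # no accelerator: plain float32
        return torch.float32

assert infer_dtype(npu=False, bf16=True, fp16=True) == torch.bfloat16
assert infer_dtype(npu=False, bf16=False, fp16=False) == torch.float32  # CPU-only case
```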
```diff
@@ -211,7 +224,7 @@ def get_train_args(
     elif training_args.fp16:
         model_args.compute_dtype = torch.float16
     else:
-        model_args.compute_dtype = torch.float32
+        model_args.compute_dtype = _infer_dtype()
 
     model_args.model_max_length = data_args.cutoff_len
 
```
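In `get_train_args`, an explicitly requested precision still wins (the `elif training_args.fp16` branch above is untouched); only the final `else` branch changes, from a hard-coded `torch.float32` to the new `_infer_dtype()` call. A short sketch of that selection order (`resolve_compute_dtype` is an illustrative local name, not repository code):

```python
# Illustrative selection order after this hunk: an explicit fp16 request still
# forces float16, and auto-detection is only the default fallback.
import torch

def resolve_compute_dtype(fp16_flag: bool, inferred: torch.dtype) -> torch.dtype:
    if fp16_flag:
        return torch.float16      # user asked for fp16 explicitly
    else:
        return inferred           # previously always torch.float32

print(resolve_compute_dtype(fp16_flag=False, inferred=torch.bfloat16))  # torch.bfloat16
```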
```diff
@@ -252,11 +265,6 @@ def get_infer_args(
         raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
 
     # auto-detect cuda capability
-    if is_npu_available:
-        model_args.compute_dtype = torch.float16
-    elif is_bf16_available:
-        model_args.compute_dtype = torch.bfloat16
-    else:
-        model_args.compute_dtype = torch.float16
+    model_args.compute_dtype = _infer_dtype()
 
     return model_args, data_args, finetuning_args, generating_args
```
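In `get_infer_args`, the manual NPU/bf16/float16 cascade is replaced by the same `_infer_dtype()` call. The only behavioural difference is the machine with no accelerator: it previously fell through to float16 and now gets float32, as the sketch below shows (flag values are stubbed for illustration):

```python
# Before/after sketch for the no-accelerator case (flags stubbed to a CPU-only box).
import torch

is_npu_available, is_bf16_available, is_fp16_available = False, False, False

# removed cascade: NPU -> fp16, bf16 GPU -> bf16, otherwise fp16
old_dtype = (torch.float16 if is_npu_available
             else torch.bfloat16 if is_bf16_available
             else torch.float16)

# new _infer_dtype() cascade: NPU -> fp16, bf16 GPU -> bf16, CUDA -> fp16, else fp32
new_dtype = (torch.float16 if is_npu_available
             else torch.bfloat16 if is_bf16_available
             else torch.float16 if is_fp16_available
             else torch.float32)

print(old_dtype, new_dtype)  # torch.float16 torch.float32
```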