diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 80424e5e..9aafd3ff 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -4,7 +4,6 @@ import torch from typing import TYPE_CHECKING, Tuple from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList from transformers.utils import ( - is_torch_bf16_cpu_available, is_torch_bf16_gpu_available, is_torch_cuda_available, is_torch_npu_available, @@ -13,7 +12,7 @@ from transformers.utils import ( _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() try: - _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available() + _is_bf16_available = is_torch_bf16_gpu_available() except: _is_bf16_available = False