From 1a261add61b3269eb98b726f966495408bb1018e Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Sat, 8 Jun 2024 07:15:45 +0800 Subject: [PATCH] fix llamafactory-cli env Former-commit-id: 972ec9c668de1a9b6d872187dbc0c1d94f6fec6b --- src/llamafactory/extras/env.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/llamafactory/extras/env.py b/src/llamafactory/extras/env.py index 2b9c6458..1d4e43f1 100644 --- a/src/llamafactory/extras/env.py +++ b/src/llamafactory/extras/env.py @@ -6,10 +6,7 @@ import peft import torch import transformers import trl -from transformers.integrations import is_deepspeed_available -from transformers.utils import is_bitsandbytes_available, is_torch_cuda_available, is_torch_npu_available - -from .packages import is_vllm_available +from transformers.utils import is_torch_cuda_available, is_torch_npu_available VERSION = "0.8.1.dev0" @@ -37,19 +34,25 @@ def print_env() -> None: info["NPU type"] = torch.npu.get_device_name() info["CANN version"] = torch.version.cann - if is_deepspeed_available(): + try: import deepspeed # type: ignore info["DeepSpeed version"] = deepspeed.__version__ + except Exception: + pass - if is_bitsandbytes_available(): + try: import bitsandbytes info["Bitsandbytes version"] = bitsandbytes.__version__ + except Exception: + pass - if is_vllm_available(): + try: import vllm info["vLLM version"] = vllm.__version__ + except Exception: + pass print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")