# Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 11:42:49 +08:00)
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

|
import platform
|
|
|
|
import accelerate
|
|
import datasets
|
|
import peft
|
|
import torch
|
|
import transformers
|
|
import trl
|
|
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
|
|
|
|
|
# Package version string reported as "`llamafactory` version" by print_env below.
VERSION = "0.9.2"
def print_env() -> None:
    """Print a bulleted summary of the runtime environment.

    Reports the llamafactory version, platform/Python details, and the
    versions of the core dependencies; appends CUDA or NPU accelerator
    details when available, plus the versions of optional packages
    (DeepSpeed, bitsandbytes, vLLM) that happen to be importable.
    """
    pairs = [
        ("`llamafactory` version", VERSION),
        ("Platform", platform.platform()),
        ("Python version", platform.python_version()),
        ("PyTorch version", torch.__version__),
        ("Transformers version", transformers.__version__),
        ("Datasets version", datasets.__version__),
        ("Accelerate version", accelerate.__version__),
        ("PEFT version", peft.__version__),
        ("TRL version", trl.__version__),
    ]
    info = dict(pairs)

    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
        info["GPU type"] = torch.cuda.get_device_name()
        info["GPU number"] = torch.cuda.device_count()
        # mem_get_info()[1] is the device's total memory in bytes.
        info["GPU memory"] = f"{torch.cuda.mem_get_info()[1] / (1024**3):.2f}GB"

    if is_torch_npu_available():
        info["PyTorch version"] += " (NPU)"
        info["NPU type"] = torch.npu.get_device_name()
        info["CANN version"] = torch.version.cann

    # Best-effort probes: record each optional package's version if it
    # imports cleanly; any failure (missing or broken install) is ignored.
    optional = (
        ("deepspeed", "DeepSpeed version"),
        ("bitsandbytes", "Bitsandbytes version"),
        ("vllm", "vLLM version"),
    )
    for module_name, label in optional:
        try:
            info[label] = __import__(module_name).__version__  # type: ignore
        except Exception:
            pass

    lines = [f"- {key}: {value}" for key, value in info.items()]
    print("\n" + "\n".join(lines) + "\n")