diff --git a/src/train.py b/src/train.py index 00a7fa26..4cc21194 100644 --- a/src/train.py +++ b/src/train.py @@ -1,5 +1,8 @@ import os + +import torch from transformers import is_torch_npu_available + from llmtuner.train.tuner import run_exp