diff --git a/src/train.py b/src/train.py index 098ec1b5..00a7fa26 100644 --- a/src/train.py +++ b/src/train.py @@ -1,5 +1,5 @@ import os -from torch_npu.contrib import transfer_to_npu +from transformers import is_torch_npu_available from llmtuner.train.tuner import run_exp