mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 06:12:50 +08:00
23 lines
426 B
Python
23 lines
426 B
Python
import os
|
|
|
|
import torch
|
|
from transformers import is_torch_npu_available
|
|
|
|
from llmtuner.train.tuner import run_exp
|
|
|
|
|
|
def main():
    """Run the experiment by delegating to ``run_exp``."""
    return run_exp()
|
|
|
|
|
|
def _mp_fn(index):
    """Entry point for xla_spawn (TPUs).

    ``index`` is accepted as an argument but is not used; the function
    simply calls ``run_exp``.
    """
    run_exp()
|
|
|
|
|
|
if __name__ == "__main__":
    if is_torch_npu_available():
        # JIT compilation is opt-in: it is enabled only when the JIT_COMPILE
        # environment variable is "true" or "1" (compared case-insensitively).
        jit_setting = os.getenv('JIT_COMPILE', 'False').lower()
        torch.npu.set_compile_mode(jit_compile=jit_setting in ('true', '1'))
    # Start the run whether or not an NPU was detected.
    main()
|