# LLaMA-Factory/src/train.py
# hoshi-hiyouga 082506eba8 Update train.py
# Former-commit-id: 1c3c4989022025db756965350ae0381fc9db32e5
# 2024-05-14 20:47:52 +08:00
#
# 23 lines
# 426 B
# Python

import os
import torch
from transformers import is_torch_npu_available
from llmtuner.train.tuner import run_exp
def main() -> None:
    """Command-line entry point: run one experiment via ``run_exp``.

    ``run_exp`` is invoked with no arguments, so all configuration is taken
    from wherever ``run_exp`` itself reads it (presumably the command line —
    confirm in ``llmtuner.train.tuner``).
    """
    run_exp()


def _mp_fn(index) -> None:
    """Per-process entry point used by ``xla_spawn`` for TPU training.

    Args:
        index: Supplied by the spawner for each worker process and ignored
            here. It exists only to match the spawner's calling convention
            (presumably the process ordinal — confirm against xla_spawn).
    """
    # Same work as the command-line path: a single argument-less run_exp().
    main()


if __name__ == "__main__":
if is_torch_npu_available():
use_jit_compile = os.getenv('JIT_COMPILE', 'False').lower() in ['true', '1']
torch.npu.set_compile_mode(jit_compile=use_jit_compile)
main()