mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
[misc] upgrade cli (#7714)
This commit is contained in:
@@ -96,6 +96,7 @@ def patch_config(
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
if is_torch_npu_available():
|
||||
# avoid JIT compile on NPU devices, see https://zhuanlan.zhihu.com/p/660875458
|
||||
torch.npu.set_compile_mode(jit_compile=is_env_enabled("JIT_COMPILE"))
|
||||
|
||||
configure_attn_implementation(config, model_args, is_trainable)
|
||||
|
||||
Reference in New Issue
Block a user