add test cases

This commit is contained in:
hiyouga
2024-06-15 04:05:54 +08:00
parent 2d43b8bb49
commit b27269bd2b
9 changed files with 184 additions and 34 deletions

View File

@@ -44,7 +44,10 @@ def patch_config(
is_trainable: bool,
) -> None:
if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
if model_args.infer_dtype == "auto":
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
else:
model_args.compute_dtype = getattr(torch, model_args.infer_dtype)
if is_torch_npu_available():
use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"]