From 75829c8699e33db58a8d1daf4e6e0740e8bc2377 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sun, 24 Mar 2024 00:34:54 +0800 Subject: [PATCH] fix #2928 Former-commit-id: 7afbc85daee295cf38dcee9ded5afd87b2c4cfd1 --- src/llmtuner/model/loader.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index c7ffb675..b1816aa7 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -110,6 +110,9 @@ def load_model( if not is_trainable: model.requires_grad_(False) model.eval() + for param in model.parameters(): + if param.device.type == "cuda": + param.data = param.data.to(model_args.compute_dtype) else: model.train()