diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index c7ffb675..b1816aa7 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -110,6 +110,9 @@ def load_model( if not is_trainable: model.requires_grad_(False) model.eval() + for param in model.parameters(): + if param.device.type == "cuda": + param.data = param.data.to(model_args.compute_dtype) else: model.train()