From f51b435bcfb04eb919b119a274f03f4b399b5981 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 25 Jun 2024 02:34:04 +0800
Subject: [PATCH] fix #4432

Former-commit-id: 972a3b469c600bc6528aef3a49b6fdec63d65803
---
 src/llamafactory/model/loader.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 69cccd93..e1015821 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -14,6 +14,7 @@
 
 from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict
 
+import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -175,6 +176,10 @@ def load_model(
 
     if not is_trainable:
         model.requires_grad_(False)
+        for param in model.parameters():
+            if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32:
+                param.data = param.data.to(model_args.compute_dtype)
+
         model.eval()
     else:
         model.train()