diff --git a/src/llamafactory/model/utils/quantization.py b/src/llamafactory/model/utils/quantization.py
index 95412e7c..161ad5aa 100644
--- a/src/llamafactory/model/utils/quantization.py
+++ b/src/llamafactory/model/utils/quantization.py
@@ -131,7 +131,7 @@ def configure_quantization(
             bnb_4bit_compute_dtype=model_args.compute_dtype,
             bnb_4bit_use_double_quant=model_args.double_quantization,
             bnb_4bit_quant_type=model_args.quantization_type,
-            bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp qlora
+            bnb_4bit_quant_storage=model_args.compute_dtype,  # crucial for fsdp+qlora
         )
 
         if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto":
@@ -141,6 +141,7 @@ def configure_quantization(
             require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
             require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
             require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+            init_kwargs["torch_dtype"] = model_args.compute_dtype  # fsdp+qlora requires same dtype
         else:
             init_kwargs["device_map"] = {"": get_current_device()}
 
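For context, a minimal standalone sketch of the load path this patch produces, assuming bfloat16 as the compute dtype and a placeholder model id. The point the diff makes is that `bnb_4bit_quant_storage`, `bnb_4bit_compute_dtype`, and `torch_dtype` must all agree, since FSDP flattens parameters into a single buffer and cannot mix dtypes:

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Assumed settings for illustration: bfloat16 compute dtype, NF4 double quantization.
compute_dtype = torch.bfloat16

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # FSDP flattens all parameters into one uniform-dtype buffer, so the packed
    # 4-bit weights must be stored in compute_dtype too (needs bitsandbytes>=0.43.0).
    bnb_4bit_quant_storage=compute_dtype,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder model id, not from the patch
    quantization_config=quant_config,
    # Mirrors the added init_kwargs["torch_dtype"] line: non-quantized modules
    # (norms, embeddings) must also end up in compute_dtype for FSDP+QLoRA.
    torch_dtype=compute_dtype,
)
```

Without the `torch_dtype` line, the unquantized submodules load in float32 while the quantized weights are stored as bfloat16, and FSDP sharding fails on the dtype mismatch.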