fix quantization

Former-commit-id: ccb0f58e22f55b15531fd0e85f5935b150575bec
hiyouga 2023-11-17 22:21:29 +08:00
parent f9df6c17ed
commit 0d98d1a28c


@@ -168,17 +168,12 @@ def load_model_and_tokenizer(
         config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if is_deepspeed_zero3_enabled() or getattr(config, "model_type", None) == "chatglm":
-        low_cpu_mem_usage = False
-    else:
-        low_cpu_mem_usage = True
-
     # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
         torch_dtype=model_args.compute_dtype,
-        low_cpu_mem_usage=low_cpu_mem_usage,
+        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
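For reference, a minimal standalone sketch of the pattern this change adopts, assuming a recent transformers release. The function name load_causal_lm, the dtype default, and the model path are illustrative placeholders, not code from this repository; the import path of is_deepspeed_zero3_enabled has also moved between transformers versions.

    import torch
    from transformers import AutoModelForCausalLM
    # In older transformers releases this helper lived at
    # `from transformers.deepspeed import is_deepspeed_zero3_enabled`.
    from transformers.integrations import is_deepspeed_zero3_enabled

    def load_causal_lm(model_to_load: str, compute_dtype: torch.dtype = torch.float16):
        # `low_cpu_mem_usage=True` defers weight allocation via meta tensors,
        # while DeepSpeed ZeRO-3 partitions parameters at construction time
        # through zero.Init(); the two are incompatible, so lazy loading is
        # disabled whenever ZeRO-3 is active.
        return AutoModelForCausalLM.from_pretrained(
            model_to_load,
            torch_dtype=compute_dtype,
            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
        )

Note that the diff also drops the old ChatGLM special case, so low_cpu_mem_usage is now gated on ZeRO-3 alone.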