Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-10-15 16:18:10 +08:00)
fix quantization
Former-commit-id: 8268aefe8fba268065e24ffe159a9c49f7c6f3a5
parent 5ce5ea84a9
commit a7bf0b85d7
@@ -168,17 +168,12 @@ def load_model_and_tokenizer(
         config_kwargs["device_map"] = {"": get_current_device()}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if is_deepspeed_zero3_enabled() or getattr(config, "model_type", None) == "chatglm":
-        low_cpu_mem_usage = False
-    else:
-        low_cpu_mem_usage = True
-
     # Load pre-trained models (without valuehead)
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
         torch_dtype=model_args.compute_dtype,
-        low_cpu_mem_usage=low_cpu_mem_usage,
+        low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
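What this hunk does: the old code chose low_cpu_mem_usage via an if/else that forced it to False under DeepSpeed ZeRO-3 or for ChatGLM models; the new code derives the flag inline from the ZeRO-3 check alone, dropping the ChatGLM special case. Below is a minimal sketch of the resulting load path, assuming a recent transformers release; the wrapper name load_base_model is hypothetical, only the from_pretrained call is taken from the diff, and in older transformers versions is_deepspeed_zero3_enabled is imported from transformers.deepspeed instead of transformers.integrations.

    # Minimal sketch, not the repo's full load_model_and_tokenizer.
    from transformers import AutoModelForCausalLM
    from transformers.integrations import is_deepspeed_zero3_enabled

    def load_base_model(model_to_load, config, compute_dtype, **config_kwargs):
        # Under ZeRO-3, parameters are partitioned at construction time
        # (deepspeed.zero.Init), and transformers rejects
        # low_cpu_mem_usage=True in that mode, so the flag is simply the
        # negation of the ZeRO-3 check.
        return AutoModelForCausalLM.from_pretrained(
            model_to_load,
            config=config,
            torch_dtype=compute_dtype,
            low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
            **config_kwargs,
        )

The likely rationale: low_cpu_mem_usage=True is a host-RAM optimization for ordinary single-process loading, while ZeRO-3 needs to own parameter initialization, so deriving the flag directly from is_deepspeed_zero3_enabled() covers the one case where it must be off without maintaining a per-architecture list.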