mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-15 08:08:09 +08:00
fix quant infer and qwen2moe
Former-commit-id: b75d16767f35c36e2cf2aaab8a3844135085bccf
This commit is contained in:
parent
6030a4a720
commit
566d71b7a9
@ -109,9 +109,6 @@ def load_model(
|
|||||||
if not is_trainable:
|
if not is_trainable:
|
||||||
model.requires_grad_(False)
|
model.requires_grad_(False)
|
||||||
model.eval()
|
model.eval()
|
||||||
for param in model.parameters():
|
|
||||||
if param.device.type == "cuda":
|
|
||||||
param.data = param.data.to(model_args.compute_dtype)
|
|
||||||
else:
|
else:
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
|
@ -316,6 +316,9 @@ def patch_config(
|
|||||||
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
|
if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn:
|
||||||
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
|
setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flashattn
|
||||||
|
|
||||||
|
if getattr(config, "model_type", None) == "qwen2_moe" and is_trainable:
|
||||||
|
setattr(config, "output_router_logits", True)
|
||||||
|
|
||||||
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
init_kwargs["torch_dtype"] = model_args.compute_dtype
|
||||||
if not is_deepspeed_zero3_enabled():
|
if not is_deepspeed_zero3_enabled():
|
||||||
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
|
init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage
|
||||||
|
Loading…
x
Reference in New Issue
Block a user