mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-01 03:02:51 +08:00
[model] fix use_cache patching for gemma3 multimodal (#7500)
This commit is contained in:
parent
f06a74ad4e
commit
9deece1d50
@ -107,6 +107,10 @@ def patch_config(
|
||||
setattr(config, "use_cache", True)
|
||||
logger.info_rank0("Using KV cache for faster generation.")
|
||||
|
||||
if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
|
||||
text_config = config.text_config
|
||||
setattr(text_config, "use_cache", False)
|
||||
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user