From 9deece1d50ecfabec3ba5ab7d80cdd2f132384a8 Mon Sep 17 00:00:00 2001
From: Yu Shi Jie
Date: Tue, 1 Apr 2025 04:06:48 -0400
Subject: [PATCH] [model] fix use_cache patching for gemma3 multimodal (#7500)

---
 src/llamafactory/model/patcher.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 863197e5..a1142a30 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -107,6 +107,10 @@ def patch_config(
         setattr(config, "use_cache", True)
         logger.info_rank0("Using KV cache for faster generation.")
 
+    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
+        text_config = config.text_config
+        setattr(text_config, "use_cache", False)
+
     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
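
For context: Gemma3 multimodal checkpoints use a composite config in which the
language-model settings, including use_cache, live on the nested text_config.
Setting use_cache only on the top-level config therefore does not reach the
text backbone, which is the behavior this hunk corrects. A minimal sketch of
that nesting, assuming the Hugging Face model id "google/gemma-3-4b-it" and
the transformers Gemma3 config layout (not part of the patch):

    from transformers import AutoConfig

    # Load a Gemma3 multimodal config; config.text_config holds the
    # language-model settings, with its own independent use_cache flag.
    config = AutoConfig.from_pretrained("google/gemma-3-4b-it")

    # Setting the flag on the composite config (the pre-patch behavior)
    # leaves the nested flag untouched at its default.
    setattr(config, "use_cache", False)
    print(getattr(config.text_config, "use_cache", None))  # still True

    # What this patch adds: propagate the flag to the nested text_config.
    setattr(config.text_config, "use_cache", False)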