diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 863197e5..a1142a30 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -107,6 +107,10 @@ def patch_config( setattr(config, "use_cache", True) logger.info_rank0("Using KV cache for faster generation.") + if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache: + text_config = config.text_config + setattr(text_config, "use_cache", False) + if getattr(config, "model_type", None) == "qwen": setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: