[model] fix use_cache patching for gemma3 multimodal (#7500)

commit 9deece1d50 (parent f06a74ad4e)
Author: Yu Shi Jie
Date:   2025-04-01 04:06:48 -04:00
Committed by: GitHub


@@ -107,6 +107,10 @@ def patch_config(
         setattr(config, "use_cache", True)
         logger.info_rank0("Using KV cache for faster generation.")
 
+    if config.architectures[0] == "Gemma3ForConditionalGeneration" and not model_args.use_cache:
+        text_config = config.text_config
+        setattr(text_config, "use_cache", False)
+
     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
         for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
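
For context, a minimal sketch of the config layout that motivates this patch (assuming a transformers release with Gemma3 support; google/gemma-3-4b-it is used here only as an example checkpoint): multimodal Gemma3 checkpoints use a composite config, and the use_cache flag that the language model actually reads lives on the nested text_config, so setting it only on the top-level config leaves the KV cache enabled.

from transformers import AutoConfig

# Example multimodal Gemma3 checkpoint; any checkpoint whose architectures
# list "Gemma3ForConditionalGeneration" behaves the same way.
config = AutoConfig.from_pretrained("google/gemma-3-4b-it")
print(config.architectures[0])  # Gemma3ForConditionalGeneration

# Setting the flag on the composite config alone does not reach the
# language model, whose generation flags sit on the nested text config.
setattr(config, "use_cache", False)

# The fix above mirrors the flag onto the nested text config as well.
setattr(config.text_config, "use_cache", False)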