Former-commit-id: 1ebc890a5ff7b034c112bc9cf5cd8a6936613572
This commit is contained in:
hiyouga 2024-05-19 17:07:57 +08:00
parent 62ddab4b3a
commit df4aec7e72

View File

@ -69,7 +69,7 @@ def autocast_projector_dtype(
) -> "torch.Tensor": ) -> "torch.Tensor":
return output.to(model_args.compute_dtype) return output.to(model_args.compute_dtype)
if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None): if hasattr(model, mm_projector_name) and getattr(model, "quantization_method", None):
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype)) logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name) mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
mm_projector.register_forward_hook(_mm_projector_forward_post_hook) mm_projector.register_forward_hook(_mm_projector_forward_post_hook)