From 29417d1f9b181f907f7e3729791a43554f3bbf56 Mon Sep 17 00:00:00 2001 From: RWL Date: Fri, 22 Oct 2021 04:50:53 -0700 Subject: [PATCH] NaN (divide by zero) fix for issue #561 and #790 (#891) Summary: https://github.com/facebookresearch/pytorch3d/issues/561 https://github.com/facebookresearch/pytorch3d/issues/790 Divide by zero fix (NaN fix). When perspective_correct=True, BarycentricPerspectiveCorrectionForward and BarycentricPerspectiveCorrectionBackward in ../csrc/utils/geometry_utils.cuh are called. The denominator (denom) values should not be allowed to go to zero. I'm able to resolve this issue locally with this PR and submit it for the team's review. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/891 Reviewed By: patricklabatut Differential Revision: D31829695 Pulled By: bottler fbshipit-source-id: a3517b8362f6e60d48c35731258d8ce261b1d912 --- pytorch3d/csrc/utils/geometry_utils.cuh | 4 ++-- pytorch3d/csrc/utils/geometry_utils.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch3d/csrc/utils/geometry_utils.cuh b/pytorch3d/csrc/utils/geometry_utils.cuh index 53ff4f5a..9e2979ac 100644 --- a/pytorch3d/csrc/utils/geometry_utils.cuh +++ b/pytorch3d/csrc/utils/geometry_utils.cuh @@ -177,7 +177,7 @@ __device__ inline float3 BarycentricPerspectiveCorrectionForward( const float w0_top = bary.x * z1 * z2; const float w1_top = z0 * bary.y * z2; const float w2_top = z0 * z1 * bary.z; - const float denom = w0_top + w1_top + w2_top; + const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon); const float w0 = w0_top / denom; const float w1 = w1_top / denom; const float w2 = w2_top / denom; @@ -208,7 +208,7 @@ BarycentricPerspectiveCorrectionBackward( const float w0_top = bary.x * z1 * z2; const float w1_top = z0 * bary.y * z2; const float w2_top = z0 * z1 * bary.z; - const float denom = w0_top + w1_top + w2_top; + const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon); // Now do backward pass const float grad_denom_top = diff --git a/pytorch3d/csrc/utils/geometry_utils.h b/pytorch3d/csrc/utils/geometry_utils.h index c8b57f53..407849d8 100644 --- a/pytorch3d/csrc/utils/geometry_utils.h +++ b/pytorch3d/csrc/utils/geometry_utils.h @@ -198,7 +198,7 @@ inline vec3 BarycentricPerspectiveCorrectionForward( const T w0_top = bary.x * z1 * z2; const T w1_top = bary.y * z0 * z2; const T w2_top = bary.z * z0 * z1; - const T denom = w0_top + w1_top + w2_top; + const T denom = std::max(w0_top + w1_top + w2_top, kEpsilon); const T w0 = w0_top / denom; const T w1 = w1_top / denom; const T w2 = w2_top / denom; @@ -229,7 +229,7 @@ inline std::tuple, T, T, T> BarycentricPerspectiveCorrectionBackward( const T w0_top = bary.x * z1 * z2; const T w1_top = bary.y * z0 * z2; const T w2_top = bary.z * z0 * z1; - const T denom = w0_top + w1_top + w2_top; + const T denom = std::max(w0_top + w1_top + w2_top, kEpsilon); // Now do backward pass const T grad_denom_top =