mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 23:06:04 +08:00
fix alpha compositing
Summary: Fix division by zero when alpha is 1.0 In this case, the nominator is already 0 and we need to make sure division with 0 does not occur which would produce nans Reviewed By: nikhilaravi Differential Revision: D21650478 fbshipit-source-id: bc457105b3050fef1c8bd4e58e7d6d15c0c81ffd
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f2d1d2db69
commit
d689baac5e
@@ -11,6 +11,8 @@
|
||||
#include <stdio.h>
|
||||
#include <vector>
|
||||
|
||||
__constant__ const float kEpsilon = 1e-9;
|
||||
|
||||
// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
|
||||
// Currently, support is for floats only.
|
||||
__global__ void alphaCompositeCudaForwardKernel(
|
||||
@@ -126,7 +128,7 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
||||
atomicAdd(
|
||||
&grad_alphas[batch][t][j][i],
|
||||
-grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha *
|
||||
alpha / (1 - alpha_tvalue));
|
||||
alpha / (1 - alpha_tvalue + kEpsilon));
|
||||
}
|
||||
|
||||
cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]);
|
||||
|
||||
Reference in New Issue
Block a user