fix alpha compositing

Summary:
Fix division by zero when alpha is 1.0.
In this case the numerator is already 0, so we need to make sure that division by 0 does not occur, which would otherwise produce NaNs.
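
For illustration only (this snippet is not part of the commit): a minimal standalone C++ sketch of the failure mode and of the epsilon guard added below. The kEps value matches the patch; the other names and values are made up for the example, and IEEE 754 float semantics (no -ffast-math) are assumed.

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  // Same guard value as the kEps / kEpsilon constant added by the patch.
  const float kEps = 1e-9f;

  // A fully opaque point: alpha_tvalue == 1.0, so the denominator
  // (1 - alpha_tvalue) is 0. In the backward pass the accumulated weight
  // cum_alpha already carries a (1 - alpha_tvalue) factor (here simplified
  // to a single earlier point), which is why the numerator is already 0.
  const float alpha_tvalue = 1.0f;
  const float cum_alpha = 1.0f - alpha_tvalue; // == 0
  const float grad_output = 0.5f, feature = 0.3f, alpha = 1.0f;

  // Unguarded expression: 0 / 0 evaluates to NaN and poisons the gradients.
  const float unguarded =
      grad_output * feature * cum_alpha * alpha / (1.0f - alpha_tvalue);

  // Guarded expression, as in the patch: 0 / kEps == 0, the correct limit.
  const float guarded =
      grad_output * feature * cum_alpha * alpha / (1.0f - alpha_tvalue + kEps);

  assert(std::isnan(unguarded));
  assert(guarded == 0.0f);
  std::printf("unguarded = %f, guarded = %f\n", unguarded, guarded);
  return 0;
}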

Reviewed By: nikhilaravi

Differential Revision: D21650478

fbshipit-source-id: bc457105b3050fef1c8bd4e58e7d6d15c0c81ffd
Author: Georgia Gkioxari  2020-05-20 09:25:44 -07:00
Committed by: Facebook GitHub Bot
Commit: d689baac5e (parent f2d1d2db69)

2 changed files with 8 additions and 2 deletions


@@ -11,6 +11,8 @@
 #include <stdio.h>
 #include <vector>
 
+__constant__ const float kEpsilon = 1e-9;
+
 // TODO(gkioxari) support all data types once AtomicAdd supports doubles.
 // Currently, support is for floats only.
 __global__ void alphaCompositeCudaForwardKernel(
@@ -126,7 +128,7 @@ __global__ void alphaCompositeCudaBackwardKernel(
         atomicAdd(
             &grad_alphas[batch][t][j][i],
             -grad_outputs[batch][ch][j][i] * features[ch][n_idx] * cum_alpha *
-                alpha / (1 - alpha_tvalue));
+                alpha / (1 - alpha_tvalue + kEpsilon));
       }
       cum_alpha = cum_alpha * (1 - alphas[batch][k][j][i]);


@@ -5,6 +5,9 @@
 #include <cmath>
 #include <vector>
 
+// Epsilon float
+const float kEps = 1e-9;
+
 torch::Tensor alphaCompositeCpuForward(
     const torch::Tensor& features,
     const torch::Tensor& alphas,
@@ -101,7 +104,8 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeCpuBackward(
         }
         float alpha_tvalue = alphas_a[b][t][j][i];
         grad_alphas_a[b][t][j][i] -= grad_outputs_a[b][c][j][i] *
-            features_a[c][n_idx] * cum_alpha * alpha / (1 - alpha_tvalue);
+            features_a[c][n_idx] * cum_alpha * alpha /
+            (1 - alpha_tvalue + kEps);
       }
       cum_alpha = cum_alpha * (1 - alpha);
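
A hedged aside on the design choice, not something this commit does: the same guard could be factored into a small helper instead of writing "+ kEps" / "+ kEpsilon" at each division site. The helper below is purely hypothetical; it is safe for this use only because the denominator (1 - alpha_tvalue) lies in [0, 1], so adding a positive epsilon cannot change its sign.

// Hypothetical helper (not in the commit): divide while nudging the
// denominator away from zero, so a zero numerator yields 0 instead of NaN.
inline float safe_div(float num, float denom, float eps = 1e-9f) {
  return num / (denom + eps);
}

// Usage mirroring the CPU backward term above:
//   grad_alphas_a[b][t][j][i] -= safe_div(
//       grad_outputs_a[b][c][j][i] * features_a[c][n_idx] * cum_alpha * alpha,
//       1 - alpha_tvalue);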