mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
CUDA/C++ Rasterizer updates to handle clipped faces
Summary: - Updated the C++/CUDA mesh rasterization kernels to handle the clipped faces. In particular this required careful handling of the distance calculation for faces which are cut into a quadrilateral by the image plane and then split into two sub triangles i.e. both sub triangles can't be part of the top K faces. - Updated `rasterize_meshes.py` to use the utils functions to clip the meshes and convert the fragments back to in terms of the unclipped mesh - Added end to end tests Reviewed By: jcjohnson Differential Revision: D26169685 fbshipit-source-id: d64cd0d656109b965f44a35c301b7c81f451cfa0
This commit is contained in:
parent
838b73d3b6
commit
340662e98e
@ -17,8 +17,8 @@ namespace {
|
||||
// A structure for holding details about a pixel.
|
||||
struct Pixel {
|
||||
float z;
|
||||
int64_t idx;
|
||||
float dist;
|
||||
int64_t idx; // idx of face
|
||||
float dist; // abs distance of pixel to face
|
||||
float3 bary;
|
||||
};
|
||||
|
||||
@ -111,6 +111,7 @@ __device__ bool CheckPointOutsideBoundingBox(
|
||||
template <typename FaceQ>
|
||||
__device__ void CheckPixelInsideFace(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int64_t* clipped_faces_neighbor_idx, // (F,)
|
||||
const int face_idx,
|
||||
int& q_size,
|
||||
float& q_max_z,
|
||||
@ -173,32 +174,72 @@ __device__ void CheckPixelInsideFace(
|
||||
// face.
|
||||
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
|
||||
const float signed_dist = inside ? -dist : dist;
|
||||
|
||||
// Check if pixel is outside blur region
|
||||
if (!inside && dist >= blur_radius) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (q_size < K) {
|
||||
// Just insert it.
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = q_size;
|
||||
// Handle the case where a face (f) partially behind the image plane is
|
||||
// clipped to a quadrilateral and then split into two faces (t1, t2). In this
|
||||
// case we:
|
||||
// 1. Find the index of the neighboring face (e.g. for t1 need index of t2)
|
||||
// 2. Check if the neighboring face (t2) is already in the top K faces
|
||||
// 3. If yes, compare the distance of the pixel to t1 with the distance to t2.
|
||||
// 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K faces.
|
||||
const int neighbor_idx = clipped_faces_neighbor_idx[face_idx];
|
||||
int neighbor_idx_top_k = -1;
|
||||
|
||||
// Check if neighboring face is already in the top K.
|
||||
// -1 is the fill value in clipped_faces_neighbor_idx
|
||||
if (neighbor_idx != -1) {
|
||||
// Only need to loop until q_size.
|
||||
for (int i = 0; i < q_size; i++) {
|
||||
if (q[i].idx == neighbor_idx) {
|
||||
neighbor_idx_top_k = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
q_size++;
|
||||
} else if (pz < q_max_z) {
|
||||
// Overwrite the old max, and find the new max.
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
q_max_z = pz;
|
||||
for (int i = 0; i < K; i++) {
|
||||
if (q[i].z > q_max_z) {
|
||||
q_max_z = q[i].z;
|
||||
q_max_idx = i;
|
||||
}
|
||||
// If neighbor idx is not -1 then it is in the top K struct.
|
||||
if (neighbor_idx_top_k != -1) {
|
||||
// If dist of current face is less than neighbor then overwrite the
|
||||
// neighbor face values in the top K struct.
|
||||
float neighbor_dist = abs(q[neighbor_idx_top_k].dist);
|
||||
if (dist < neighbor_dist) {
|
||||
// Overwrite the neighbor face values
|
||||
q[neighbor_idx_top_k] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
|
||||
// If pz > q_max then overwrite the max values and index of the max.
|
||||
// q_size stays the same.
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = neighbor_idx_top_k;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Handle as a normal face
|
||||
if (q_size < K) {
|
||||
// Just insert it.
|
||||
q[q_size] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
if (pz > q_max_z) {
|
||||
q_max_z = pz;
|
||||
q_max_idx = q_size;
|
||||
}
|
||||
q_size++;
|
||||
} else if (pz < q_max_z) {
|
||||
// Overwrite the old max, and find the new max.
|
||||
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary_clip};
|
||||
q_max_z = pz;
|
||||
for (int i = 0; i < K; i++) {
|
||||
if (q[i].z > q_max_z) {
|
||||
q_max_z = q[i].z;
|
||||
q_max_idx = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// ****************************************************************************
|
||||
@ -208,6 +249,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
const float* face_verts,
|
||||
const int64_t* mesh_to_face_first_idx,
|
||||
const int64_t* num_faces_per_mesh,
|
||||
const int64_t* clipped_faces_neighbor_idx,
|
||||
const float blur_radius,
|
||||
const bool perspective_correct,
|
||||
const bool clip_barycentric_coords,
|
||||
@ -265,6 +307,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
|
||||
CheckPixelInsideFace(
|
||||
face_verts,
|
||||
clipped_faces_neighbor_idx,
|
||||
f,
|
||||
q_size,
|
||||
q_max_z,
|
||||
@ -298,6 +341,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& mesh_to_faces_packed_first_idx,
|
||||
const at::Tensor& num_faces_per_mesh,
|
||||
const at::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
@ -313,6 +357,10 @@ RasterizeMeshesNaiveCuda(
|
||||
num_faces_per_mesh.size(0) == mesh_to_faces_packed_first_idx.size(0),
|
||||
"num_faces_per_mesh must have save size first dimension as mesh_to_faces_packed_first_idx");
|
||||
|
||||
TORCH_CHECK(
|
||||
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
|
||||
"clipped_faces_neighbor_idx must have save size first dimension as face_verts");
|
||||
|
||||
if (num_closest > kMaxPointsPerPixel) {
|
||||
std::stringstream ss;
|
||||
ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
|
||||
@ -323,11 +371,16 @@ RasterizeMeshesNaiveCuda(
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
mesh_to_faces_packed_first_idx_t{
|
||||
mesh_to_faces_packed_first_idx, "mesh_to_faces_packed_first_idx", 2},
|
||||
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3};
|
||||
num_faces_per_mesh_t{num_faces_per_mesh, "num_faces_per_mesh", 3},
|
||||
clipped_faces_neighbor_idx_t{
|
||||
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 4};
|
||||
at::CheckedFrom c = "RasterizeMeshesNaiveCuda";
|
||||
at::checkAllSameGPU(
|
||||
c,
|
||||
{face_verts_t, mesh_to_faces_packed_first_idx_t, num_faces_per_mesh_t});
|
||||
{face_verts_t,
|
||||
mesh_to_faces_packed_first_idx_t,
|
||||
num_faces_per_mesh_t,
|
||||
clipped_faces_neighbor_idx_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
@ -358,6 +411,7 @@ RasterizeMeshesNaiveCuda(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
mesh_to_faces_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
||||
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
|
||||
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
|
||||
blur_radius,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
@ -800,6 +854,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
||||
__global__ void RasterizeMeshesFineCudaKernel(
|
||||
const float* face_verts, // (F, 3, 3)
|
||||
const int32_t* bin_faces, // (N, BH, BW, T)
|
||||
const int64_t* clipped_faces_neighbor_idx, // (F,)
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
const bool perspective_correct,
|
||||
@ -858,6 +913,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
int q_size = 0;
|
||||
float q_max_z = -1000;
|
||||
int q_max_idx = -1;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
|
||||
if (f < 0) {
|
||||
@ -867,6 +923,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
// update q, q_size, q_max_z and q_max_idx in place.
|
||||
CheckPixelInsideFace(
|
||||
face_verts,
|
||||
clipped_faces_neighbor_idx,
|
||||
f,
|
||||
q_size,
|
||||
q_max_z,
|
||||
@ -906,6 +963,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
RasterizeMeshesFineCuda(
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& bin_faces,
|
||||
const at::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
@ -918,12 +976,18 @@ RasterizeMeshesFineCuda(
|
||||
face_verts.size(2) == 3,
|
||||
"face_verts must have dimensions (num_faces, 3, 3)");
|
||||
TORCH_CHECK(bin_faces.ndimension() == 4, "bin_faces must have 4 dimensions");
|
||||
TORCH_CHECK(
|
||||
clipped_faces_neighbor_idx.size(0) == face_verts.size(0),
|
||||
"clipped_faces_neighbor_idx must have the same first dimension as face_verts");
|
||||
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg face_verts_t{face_verts, "face_verts", 1},
|
||||
bin_faces_t{bin_faces, "bin_faces", 2};
|
||||
bin_faces_t{bin_faces, "bin_faces", 2},
|
||||
clipped_faces_neighbor_idx_t{
|
||||
clipped_faces_neighbor_idx, "clipped_faces_neighbor_idx", 3};
|
||||
at::CheckedFrom c = "RasterizeMeshesFineCuda";
|
||||
at::checkAllSameGPU(c, {face_verts_t, bin_faces_t});
|
||||
at::checkAllSameGPU(
|
||||
c, {face_verts_t, bin_faces_t, clipped_faces_neighbor_idx_t});
|
||||
|
||||
// Set the device for the kernel launch based on the device of the input
|
||||
at::cuda::CUDAGuard device_guard(face_verts.device());
|
||||
@ -961,6 +1025,7 @@ RasterizeMeshesFineCuda(
|
||||
RasterizeMeshesFineCudaKernel<<<blocks, threads, 0, stream>>>(
|
||||
face_verts.contiguous().data_ptr<float>(),
|
||||
bin_faces.contiguous().data_ptr<int32_t>(),
|
||||
clipped_faces_neighbor_idx.contiguous().data_ptr<int64_t>(),
|
||||
blur_radius,
|
||||
bin_size,
|
||||
perspective_correct,
|
||||
|
@ -15,6 +15,7 @@ RasterizeMeshesNaiveCpu(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
@ -28,6 +29,7 @@ RasterizeMeshesNaiveCuda(
|
||||
const at::Tensor& face_verts,
|
||||
const at::Tensor& mesh_to_face_first_idx,
|
||||
const at::Tensor& num_faces_per_mesh,
|
||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int num_closest,
|
||||
@ -48,6 +50,12 @@ RasterizeMeshesNaiveCuda(
|
||||
// the batch where N is the batch size.
|
||||
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
|
||||
// for each mesh in the batch.
|
||||
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
|
||||
// index of the neighboring face for each face which was clipped to a
|
||||
// quadrilateral and then divided into two triangles.
|
||||
// e.g. for a face f partially behind the image plane which is split into
|
||||
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
|
||||
// Faces which are not clipped and subdivided are set to -1.
|
||||
// image_size: Tuple (H, W) giving the size in pixels of the output
|
||||
// image to be rasterized.
|
||||
// blur_radius: float distance in NDC coordinates uses to expand the face
|
||||
@ -90,6 +98,7 @@ RasterizeMeshesNaive(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
@ -106,6 +115,7 @@ RasterizeMeshesNaive(
|
||||
face_verts,
|
||||
mesh_to_face_first_idx,
|
||||
num_faces_per_mesh,
|
||||
clipped_faces_neighbor_idx,
|
||||
image_size,
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
@ -120,6 +130,7 @@ RasterizeMeshesNaive(
|
||||
face_verts,
|
||||
mesh_to_face_first_idx,
|
||||
num_faces_per_mesh,
|
||||
clipped_faces_neighbor_idx,
|
||||
image_size,
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
@ -306,6 +317,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizeMeshesFineCuda(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& bin_faces,
|
||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
@ -322,6 +334,12 @@ RasterizeMeshesFineCuda(
|
||||
// in NDC coordinates in the range [-1, 1].
|
||||
// bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
|
||||
// that fall into each bin (output from coarse rasterization).
|
||||
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
|
||||
// index of the neighboring face for each face which was clipped to a
|
||||
// quadrilateral and then divided into two triangles.
|
||||
// e.g. for a face f partially behind the image plane which is split into
|
||||
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
|
||||
// Faces which are not clipped and subdivided are set to -1.
|
||||
// image_size: Tuple (H, W) giving the size in pixels of the output
|
||||
// image to be rasterized.
|
||||
// blur_radius: float distance in NDC coordinates uses to expand the face
|
||||
@ -364,6 +382,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizeMeshesFine(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& bin_faces,
|
||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int bin_size,
|
||||
@ -378,6 +397,7 @@ RasterizeMeshesFine(
|
||||
return RasterizeMeshesFineCuda(
|
||||
face_verts,
|
||||
bin_faces,
|
||||
clipped_faces_neighbor_idx,
|
||||
image_size,
|
||||
blur_radius,
|
||||
bin_size,
|
||||
@ -411,6 +431,12 @@ RasterizeMeshesFine(
|
||||
// the batch where N is the batch size.
|
||||
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
|
||||
// for each mesh in the batch.
|
||||
// clipped_faces_neighbor_idx: LongTensor of shape (F,) giving the
|
||||
// index of the neighboring face for each face which was clipped to a
|
||||
// quadrilateral and then divided into two triangles.
|
||||
// e.g. for a face f partially behind the image plane which is split into
|
||||
// two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx
|
||||
// Faces which are not clipped and subdivided are set to -1.
|
||||
// image_size: Tuple (H, W) giving the size in pixels of the output
|
||||
// image to be rasterized.
|
||||
// blur_radius: float distance in NDC coordinates uses to expand the face
|
||||
@ -456,6 +482,7 @@ RasterizeMeshes(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
@ -477,6 +504,7 @@ RasterizeMeshes(
|
||||
return RasterizeMeshesFine(
|
||||
face_verts,
|
||||
bin_faces,
|
||||
clipped_faces_neighbor_idx,
|
||||
image_size,
|
||||
blur_radius,
|
||||
bin_size,
|
||||
@ -490,6 +518,7 @@ RasterizeMeshes(
|
||||
face_verts,
|
||||
mesh_to_face_first_idx,
|
||||
num_faces_per_mesh,
|
||||
clipped_faces_neighbor_idx,
|
||||
image_size,
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
|
@ -99,11 +99,31 @@ auto ComputeFaceAreas(const torch::Tensor& face_verts) {
|
||||
return face_areas;
|
||||
}
|
||||
|
||||
// Helper function to use with std::find_if to find the index of any
|
||||
// values in the top k struct which match a given idx.
|
||||
struct IsNeighbor {
|
||||
IsNeighbor(int neighbor_idx) {
|
||||
this->neighbor_idx = neighbor_idx;
|
||||
}
|
||||
bool operator()(std::tuple<float, int, float, float, float, float> elem) {
|
||||
return (std::get<1>(elem) == neighbor_idx);
|
||||
}
|
||||
int neighbor_idx;
|
||||
};
|
||||
|
||||
// Function to sort based on the z distance in the top K queue
|
||||
bool SortTopKByZdist(
|
||||
std::tuple<float, int, float, float, float, float> a,
|
||||
std::tuple<float, int, float, float, float, float> b) {
|
||||
return std::get<0>(a) < std::get<0>(b);
|
||||
}
|
||||
|
||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||
RasterizeMeshesNaiveCpu(
|
||||
const torch::Tensor& face_verts,
|
||||
const torch::Tensor& mesh_to_face_first_idx,
|
||||
const torch::Tensor& num_faces_per_mesh,
|
||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||
const std::tuple<int, int> image_size,
|
||||
const float blur_radius,
|
||||
const int faces_per_pixel,
|
||||
@ -139,6 +159,7 @@ RasterizeMeshesNaiveCpu(
|
||||
auto zbuf_a = zbuf.accessor<float, 4>();
|
||||
auto pix_dists_a = pix_dists.accessor<float, 4>();
|
||||
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
|
||||
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
|
||||
|
||||
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
|
||||
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
|
||||
@ -168,10 +189,11 @@ RasterizeMeshesNaiveCpu(
|
||||
|
||||
// X coordinate of the left of the pixel.
|
||||
const float xf = PixToNonSquareNdc(xidx, W, H);
|
||||
// Use a priority queue to hold values:
|
||||
|
||||
// Use a deque to hold values:
|
||||
// (z, idx, r, bary.x, bary.y. bary.z)
|
||||
std::priority_queue<std::tuple<float, int, float, float, float, float>>
|
||||
q;
|
||||
// Sort the deque as needed to mimic a priority queue.
|
||||
std::deque<std::tuple<float, int, float, float, float, float>> q;
|
||||
|
||||
// Loop through the faces in the mesh.
|
||||
for (int f = face_start_idx; f < face_stop_idx; ++f) {
|
||||
@ -240,15 +262,58 @@ RasterizeMeshesNaiveCpu(
|
||||
if (!inside && dist >= blur_radius) {
|
||||
continue;
|
||||
}
|
||||
// The current pixel lies inside the current face.
|
||||
q.emplace(pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
|
||||
|
||||
// Handle the case where a face (f) partially behind the image plane
|
||||
// is clipped to a quadrilateral and then split into two faces (t1,
|
||||
// t2). In this case we:
|
||||
// 1. Find the index of the neighbor (e.g. for t1 need index of t2)
|
||||
// 2. Check if the neighbor (t2) is already in the top K faces
|
||||
// 3. If yes, compare the distance of the pixel to t1 with the
|
||||
// distance to t2.
|
||||
// 4. If dist_t1 < dist_t2, overwrite the values for t2 in the top K
|
||||
// faces.
|
||||
const int neighbor_idx = neighbor_idx_a[f];
|
||||
int idx_top_k = -1;
|
||||
|
||||
// Check if neighboring face is already in the top K.
|
||||
if (neighbor_idx != -1) {
|
||||
const auto it =
|
||||
std::find_if(q.begin(), q.end(), IsNeighbor(neighbor_idx));
|
||||
// Get the index of the element from the iterator
|
||||
idx_top_k = (it != q.end()) ? it - q.begin() : idx_top_k;
|
||||
}
|
||||
|
||||
// If idx_top_k idx is not -1 then it is in the top K struct.
|
||||
if (idx_top_k != -1) {
|
||||
// If dist of current face is less than neighbor, overwrite
|
||||
// the neighbor face values in the top K struct.
|
||||
const auto neighbor = q[idx_top_k];
|
||||
const float dist_neighbor = std::abs(std::get<2>(neighbor));
|
||||
if (dist < dist_neighbor) {
|
||||
// Overwrite the neighbor face values.
|
||||
q[idx_top_k] = {
|
||||
pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z};
|
||||
}
|
||||
} else {
|
||||
// Handle as a normal face.
|
||||
// The current pixel lies inside the current face.
|
||||
// Add at the end of the deque.
|
||||
q.emplace_back(
|
||||
pz, f, signed_dist, bary_clip.x, bary_clip.y, bary_clip.z);
|
||||
}
|
||||
|
||||
// Sort the deque inplace based on the z distance
|
||||
// to mimic using a priority queue.
|
||||
std::sort(q.begin(), q.end(), SortTopKByZdist);
|
||||
if (static_cast<int>(q.size()) > K) {
|
||||
q.pop();
|
||||
// remove the last value
|
||||
q.pop_back();
|
||||
}
|
||||
}
|
||||
while (!q.empty()) {
|
||||
auto t = q.top();
|
||||
q.pop();
|
||||
// Loop through and add values to the output tensors
|
||||
auto t = q.back();
|
||||
q.pop_back();
|
||||
const int i = q.size();
|
||||
zbuf_a[n][yi][xi][i] = std::get<0>(t);
|
||||
face_idxs_a[n][yi][xi][i] = std::get<1>(t);
|
||||
|
@ -9,6 +9,12 @@ import torch
|
||||
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
|
||||
from pytorch3d import _C
|
||||
|
||||
from .clip import (
|
||||
ClipFrustum,
|
||||
clip_faces,
|
||||
convert_clipped_rasterization_to_original_faces,
|
||||
)
|
||||
|
||||
|
||||
# TODO make the epsilon user configurable
|
||||
kEpsilon = 1e-8
|
||||
@ -28,6 +34,8 @@ def rasterize_meshes(
|
||||
perspective_correct: bool = False,
|
||||
clip_barycentric_coords: bool = False,
|
||||
cull_backfaces: bool = False,
|
||||
z_clip_value: Optional[float] = None,
|
||||
cull_to_frustum: bool = False,
|
||||
):
|
||||
"""
|
||||
Rasterize a batch of meshes given the shape of the desired output image.
|
||||
@ -67,7 +75,8 @@ def rasterize_meshes(
|
||||
will be raised. This should not affect the output values, but can affect
|
||||
the memory usage in the forward pass.
|
||||
perspective_correct: Bool, Whether to apply perspective correction when computing
|
||||
barycentric coordinates for pixels.
|
||||
barycentric coordinates for pixels. This should be set to True if a perspective
|
||||
camera is used.
|
||||
cull_backfaces: Bool, Whether to only rasterize mesh faces which are
|
||||
visible to the camera. This assumes that vertices of
|
||||
front-facing triangles are ordered in an anti-clockwise
|
||||
@ -76,6 +85,15 @@ def rasterize_meshes(
|
||||
direction. NOTE: This will only work if the mesh faces are
|
||||
consistently defined with counter-clockwise ordering when
|
||||
viewed from the outside.
|
||||
z_clip_value: if not None, then triangles will be clipped (and possibly
|
||||
subdivided into smaller triangles) such that z >= z_clip_value.
|
||||
This avoids camera projections that go to infinity as z->0.
|
||||
Default is None as clipping affects rasterization speed and
|
||||
should only be turned on if explicitly needed.
|
||||
See clip.py for all the extra computation that is required.
|
||||
cull_to_frustum: if True, triangles outside the view frustum will be culled.
|
||||
Culling involves removing all faces which fall outside view frustum.
|
||||
Default is False so that it is turned on only when needed.
|
||||
|
||||
Returns:
|
||||
4-element tuple containing
|
||||
@ -141,6 +159,42 @@ def rasterize_meshes(
|
||||
im_size = (image_size, image_size)
|
||||
max_image_size = image_size
|
||||
|
||||
clipped_faces_neighbor_idx = None
|
||||
|
||||
if z_clip_value is not None or cull_to_frustum:
|
||||
# Cull faces outside the view frustum, and clip faces that are partially
|
||||
# behind the camera into the portion of the triangle in front of the
|
||||
# camera. This may change the number of faces
|
||||
frustum = ClipFrustum(
|
||||
left=-1,
|
||||
right=1,
|
||||
top=-1,
|
||||
bottom=1,
|
||||
perspective_correct=perspective_correct,
|
||||
z_clip_value=z_clip_value,
|
||||
cull=cull_to_frustum,
|
||||
)
|
||||
clipped_faces = clip_faces(
|
||||
face_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum
|
||||
)
|
||||
face_verts = clipped_faces.face_verts
|
||||
mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx
|
||||
num_faces_per_mesh = clipped_faces.num_faces_per_mesh
|
||||
|
||||
# For case 4 clipped triangles (where a big triangle is split in two smaller triangles),
|
||||
# need the index of the neighboring clipped triangle as only one can be in
|
||||
# in the top K closest faces in the rasterization step.
|
||||
clipped_faces_neighbor_idx = clipped_faces.clipped_faces_neighbor_idx
|
||||
|
||||
if clipped_faces_neighbor_idx is None:
|
||||
# Set to the default which is all -1s.
|
||||
clipped_faces_neighbor_idx = torch.full(
|
||||
size=(face_verts.shape[0],),
|
||||
fill_value=-1,
|
||||
device=meshes.device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
# TODO: Choose naive vs coarse-to-fine based on mesh size and image size.
|
||||
if bin_size is None:
|
||||
if not verts_packed.is_cuda:
|
||||
@ -172,10 +226,11 @@ def rasterize_meshes(
|
||||
max_faces_per_bin = int(max(10000, meshes._F / 5))
|
||||
|
||||
# pyre-fixme[16]: `_RasterizeFaceVerts` has no attribute `apply`.
|
||||
return _RasterizeFaceVerts.apply(
|
||||
pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
|
||||
face_verts,
|
||||
mesh_to_face_first_idx,
|
||||
num_faces_per_mesh,
|
||||
clipped_faces_neighbor_idx,
|
||||
im_size,
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
@ -186,6 +241,17 @@ def rasterize_meshes(
|
||||
cull_backfaces,
|
||||
)
|
||||
|
||||
if z_clip_value is not None or cull_to_frustum:
|
||||
# If faces were clipped, map the rasterization result to be in terms of the
|
||||
# original unclipped faces. This may involve converting barycentric
|
||||
# coordinates
|
||||
outputs = convert_clipped_rasterization_to_original_faces(
|
||||
pix_to_face, barycentric_coords, clipped_faces
|
||||
)
|
||||
pix_to_face, barycentric_coords = outputs
|
||||
|
||||
return pix_to_face, zbuf, barycentric_coords, dists
|
||||
|
||||
|
||||
class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
"""
|
||||
@ -216,9 +282,10 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
|
||||
def forward(
|
||||
ctx,
|
||||
face_verts,
|
||||
mesh_to_face_first_idx,
|
||||
num_faces_per_mesh,
|
||||
face_verts: torch.Tensor,
|
||||
mesh_to_face_first_idx: torch.Tensor,
|
||||
num_faces_per_mesh: torch.Tensor,
|
||||
clipped_faces_neighbor_idx: torch.Tensor,
|
||||
image_size: Union[List[int], Tuple[int, int]] = (256, 256),
|
||||
blur_radius: float = 0.01,
|
||||
faces_per_pixel: int = 0,
|
||||
@ -227,12 +294,15 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
perspective_correct: bool = False,
|
||||
clip_barycentric_coords: bool = False,
|
||||
cull_backfaces: bool = False,
|
||||
z_clip_value: Optional[float] = None,
|
||||
cull_to_frustum: bool = True,
|
||||
):
|
||||
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
|
||||
pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes(
|
||||
face_verts,
|
||||
mesh_to_face_first_idx,
|
||||
num_faces_per_mesh,
|
||||
clipped_faces_neighbor_idx,
|
||||
image_size,
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
@ -242,6 +312,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
)
|
||||
|
||||
ctx.save_for_backward(face_verts, pix_to_face)
|
||||
ctx.mark_non_differentiable(pix_to_face)
|
||||
ctx.perspective_correct = perspective_correct
|
||||
@ -253,6 +324,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
grad_face_verts = None
|
||||
grad_mesh_to_face_first_idx = None
|
||||
grad_num_faces_per_mesh = None
|
||||
grad_clipped_faces_neighbor_idx = None
|
||||
grad_image_size = None
|
||||
grad_radius = None
|
||||
grad_faces_per_pixel = None
|
||||
@ -275,6 +347,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
||||
grad_face_verts,
|
||||
grad_mesh_to_face_first_idx,
|
||||
grad_num_faces_per_mesh,
|
||||
grad_clipped_faces_neighbor_idx,
|
||||
grad_image_size,
|
||||
grad_radius,
|
||||
grad_faces_per_pixel,
|
||||
@ -339,6 +412,9 @@ def rasterize_meshes_python(
|
||||
perspective_correct: bool = False,
|
||||
clip_barycentric_coords: bool = False,
|
||||
cull_backfaces: bool = False,
|
||||
z_clip_value: Optional[float] = None,
|
||||
cull_to_frustum: bool = True,
|
||||
clipped_faces_neighbor_idx: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Naive PyTorch implementation of mesh rasterization with the same inputs and
|
||||
@ -359,6 +435,26 @@ def rasterize_meshes_python(
|
||||
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
||||
|
||||
if z_clip_value is not None or cull_to_frustum:
|
||||
# Cull faces outside the view frustum, and clip faces that are partially
|
||||
# behind the camera into the portion of the triangle in front of the
|
||||
# camera. This may change the number of faces
|
||||
frustum = ClipFrustum(
|
||||
left=-1,
|
||||
right=1,
|
||||
top=-1,
|
||||
bottom=1,
|
||||
perspective_correct=perspective_correct,
|
||||
z_clip_value=z_clip_value,
|
||||
cull=cull_to_frustum,
|
||||
)
|
||||
clipped_faces = clip_faces(
|
||||
faces_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum=frustum
|
||||
)
|
||||
faces_verts = clipped_faces.face_verts
|
||||
mesh_to_face_first_idx = clipped_faces.mesh_to_face_first_idx
|
||||
num_faces_per_mesh = clipped_faces.num_faces_per_mesh
|
||||
|
||||
# Intialize output tensors.
|
||||
face_idxs = torch.full(
|
||||
(N, H, W, K), fill_value=-1, dtype=torch.int64, device=device
|
||||
@ -468,26 +564,55 @@ def rasterize_meshes_python(
|
||||
# Points inside the triangle have negative distance.
|
||||
dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
|
||||
|
||||
signed_dist = dist * -1.0 if inside else dist
|
||||
|
||||
# Add an epsilon to prevent errors when comparing distance
|
||||
# to blur radius.
|
||||
if not inside and dist >= blur_radius:
|
||||
continue
|
||||
|
||||
top_k_points.append((pz, f, bary, signed_dist))
|
||||
# Handle the case where a face (f) partially behind the image plane is
|
||||
# clipped to a quadrilateral and then split into two faces (t1, t2).
|
||||
top_k_idx = -1
|
||||
if (
|
||||
clipped_faces_neighbor_idx is not None
|
||||
and clipped_faces_neighbor_idx[f] != -1
|
||||
):
|
||||
neighbor_idx = clipped_faces_neighbor_idx[f]
|
||||
# See if neighbor_idx is in top_k and find index
|
||||
top_k_idx = [
|
||||
i
|
||||
for i, val in enumerate(top_k_points)
|
||||
if val[1] == neighbor_idx
|
||||
]
|
||||
top_k_idx = top_k_idx[0] if len(top_k_idx) > 0 else -1
|
||||
|
||||
if top_k_idx != -1 and dist < top_k_points[top_k_idx][3]:
|
||||
# Overwrite the neighbor with current face info
|
||||
top_k_points[top_k_idx] = (pz, f, bary, dist, inside)
|
||||
else:
|
||||
# Handle as a normal face
|
||||
top_k_points.append((pz, f, bary, dist, inside))
|
||||
|
||||
top_k_points.sort()
|
||||
if len(top_k_points) > K:
|
||||
top_k_points = top_k_points[:K]
|
||||
|
||||
# Save to output tensors.
|
||||
for k, (pz, f, bary, dist) in enumerate(top_k_points):
|
||||
for k, (pz, f, bary, dist, inside) in enumerate(top_k_points):
|
||||
zbuf[n, yi, xi, k] = pz
|
||||
face_idxs[n, yi, xi, k] = f
|
||||
bary_coords[n, yi, xi, k, 0] = bary[0]
|
||||
bary_coords[n, yi, xi, k, 1] = bary[1]
|
||||
bary_coords[n, yi, xi, k, 2] = bary[2]
|
||||
pix_dists[n, yi, xi, k] = dist
|
||||
# Write the signed distance
|
||||
pix_dists[n, yi, xi, k] = -dist if inside else dist
|
||||
|
||||
if z_clip_value is not None or cull_to_frustum:
|
||||
# If faces were clipped, map the rasterization result to be in terms of the
|
||||
# original unclipped faces. This may involve converting barycentric
|
||||
# coordinates
|
||||
(face_idxs, bary_coords,) = convert_clipped_rasterization_to_original_faces(
|
||||
face_idxs, bary_coords, clipped_faces
|
||||
)
|
||||
|
||||
return face_idxs, zbuf, bary_coords, pix_dists
|
||||
|
||||
|
@ -27,6 +27,8 @@ class RasterizationSettings:
|
||||
"perspective_correct",
|
||||
"clip_barycentric_coords",
|
||||
"cull_backfaces",
|
||||
"z_clip_value",
|
||||
"cull_to_frustum",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
@ -36,9 +38,13 @@ class RasterizationSettings:
|
||||
faces_per_pixel: int = 1,
|
||||
bin_size: Optional[int] = None,
|
||||
max_faces_per_bin: Optional[int] = None,
|
||||
perspective_correct: bool = False,
|
||||
# set perspective_correct = None so that the
|
||||
# value can be inferred correctly from the Camera type
|
||||
perspective_correct: Optional[bool] = None,
|
||||
clip_barycentric_coords: Optional[bool] = None,
|
||||
cull_backfaces: bool = False,
|
||||
z_clip_value: Optional[float] = None,
|
||||
cull_to_frustum: bool = False,
|
||||
):
|
||||
self.image_size = image_size
|
||||
self.blur_radius = blur_radius
|
||||
@ -48,6 +54,8 @@ class RasterizationSettings:
|
||||
self.perspective_correct = perspective_correct
|
||||
self.clip_barycentric_coords = clip_barycentric_coords
|
||||
self.cull_backfaces = cull_backfaces
|
||||
self.z_clip_value = z_clip_value
|
||||
self.cull_to_frustum = cull_to_frustum
|
||||
|
||||
|
||||
class MeshRasterizer(nn.Module):
|
||||
@ -139,12 +147,19 @@ class MeshRasterizer(nn.Module):
|
||||
if clip_barycentric_coords is None:
|
||||
clip_barycentric_coords = raster_settings.blur_radius > 0.0
|
||||
|
||||
# If not specified, infer perspective_correct from the camera
|
||||
# If not specified, infer perspective_correct and z_clip_value from the camera
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if raster_settings.perspective_correct is not None:
|
||||
perspective_correct = raster_settings.perspective_correct
|
||||
else:
|
||||
perspective_correct = cameras.is_perspective()
|
||||
if raster_settings.z_clip_value is not None:
|
||||
z_clip = raster_settings.z_clip_value
|
||||
else:
|
||||
znear = cameras.get_znear()
|
||||
if isinstance(znear, torch.Tensor):
|
||||
znear = znear.min().item()
|
||||
z_clip = None if not perspective_correct or znear is None else znear / 2
|
||||
|
||||
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
|
||||
meshes_screen,
|
||||
@ -153,9 +168,11 @@ class MeshRasterizer(nn.Module):
|
||||
faces_per_pixel=raster_settings.faces_per_pixel,
|
||||
bin_size=raster_settings.bin_size,
|
||||
max_faces_per_bin=raster_settings.max_faces_per_bin,
|
||||
perspective_correct=perspective_correct,
|
||||
clip_barycentric_coords=clip_barycentric_coords,
|
||||
perspective_correct=perspective_correct,
|
||||
cull_backfaces=raster_settings.cull_backfaces,
|
||||
z_clip_value=z_clip,
|
||||
cull_to_frustum=raster_settings.cull_to_frustum,
|
||||
)
|
||||
return Fragments(
|
||||
pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists
|
||||
|
@ -66,7 +66,7 @@ def bm_rasterize_meshes() -> None:
|
||||
# Square and non square cases
|
||||
image_size = [64, 128, 512, (512, 256), (256, 512)]
|
||||
blur = [1e-6]
|
||||
faces_per_pixel = [50]
|
||||
faces_per_pixel = [40]
|
||||
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
|
||||
|
||||
for case in test_cases:
|
||||
@ -87,6 +87,35 @@ def bm_rasterize_meshes() -> None:
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
# Test a subset of the cases with the
|
||||
# image plane intersecting the mesh.
|
||||
kwargs_list = []
|
||||
num_meshes = [8, 16]
|
||||
# Square and non square cases
|
||||
image_size = [64, 128, 512, (512, 256), (256, 512)]
|
||||
dist = [3, 0.8, 0.5]
|
||||
test_cases = product(num_meshes, dist, image_size)
|
||||
|
||||
for case in test_cases:
|
||||
n, d, im = case
|
||||
kwargs_list.append(
|
||||
{
|
||||
"num_meshes": n,
|
||||
"ico_level": 4,
|
||||
"image_size": im,
|
||||
"blur_radius": 1e-6,
|
||||
"faces_per_pixel": 40,
|
||||
"dist": d,
|
||||
}
|
||||
)
|
||||
|
||||
benchmark(
|
||||
TestRasterizeMeshes.bm_rasterize_meshes_with_clipping,
|
||||
"RASTERIZE_MESHES_CUDA_CLIPPING",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bm_rasterize_meshes()
|
||||
|
BIN
tests/data/room.jpg
Normal file
BIN
tests/data/room.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.8 KiB |
BIN
tests/data/test_render_mesh_clipped_cam_dist=0.5.jpg
Normal file
BIN
tests/data/test_render_mesh_clipped_cam_dist=0.5.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.2 KiB |
@ -6,6 +6,8 @@ import unittest
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.renderer import FoVPerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.renderer.mesh import MeshRasterizer, RasterizationSettings
|
||||
from pytorch3d.renderer.mesh.rasterize_meshes import (
|
||||
rasterize_meshes,
|
||||
rasterize_meshes_python,
|
||||
@ -1204,3 +1206,50 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
return rasterize
|
||||
|
||||
@staticmethod
|
||||
def bm_rasterize_meshes_with_clipping(
|
||||
num_meshes: int,
|
||||
ico_level: int,
|
||||
image_size: int,
|
||||
blur_radius: float,
|
||||
faces_per_pixel: int,
|
||||
dist: float,
|
||||
):
|
||||
device = get_random_cuda_device()
|
||||
meshes = ico_sphere(ico_level, device)
|
||||
meshes_batch = meshes.extend(num_meshes)
|
||||
|
||||
settings = RasterizationSettings(
|
||||
image_size=image_size,
|
||||
blur_radius=blur_radius,
|
||||
faces_per_pixel=faces_per_pixel,
|
||||
z_clip_value=1e-2,
|
||||
perspective_correct=True,
|
||||
cull_to_frustum=True,
|
||||
)
|
||||
|
||||
# The camera is positioned so that the image plane intersects
|
||||
# the mesh and some faces are partially behind the image plane.
|
||||
R, T = look_at_view_transform(dist, 0, 0)
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
|
||||
rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
|
||||
|
||||
# Transform the meshes to projec them onto the image plane
|
||||
meshes_screen = rasterizer.transform(meshes_batch)
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
def rasterize():
|
||||
# Only measure rasterization speed (including clipping)
|
||||
rasterize_meshes(
|
||||
meshes_screen,
|
||||
image_size,
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
z_clip_value=1e-2,
|
||||
perspective_correct=True,
|
||||
cull_to_frustum=True,
|
||||
)
|
||||
torch.cuda.synchronize(device)
|
||||
|
||||
return rasterize
|
||||
|
@ -245,6 +245,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase):
|
||||
image_size=image_size,
|
||||
blur_radius=1e-3,
|
||||
faces_per_pixel=10,
|
||||
z_clip_value=None,
|
||||
perspective_correct=False,
|
||||
),
|
||||
),
|
||||
|
@ -994,6 +994,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
blur_radius=0.0,
|
||||
faces_per_pixel=1,
|
||||
cull_backfaces=True,
|
||||
perspective_correct=False,
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
|
@ -9,14 +9,161 @@ See pytorch3d/renderer/mesh/clip.py for more details about the
|
||||
clipping process.
|
||||
"""
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
|
||||
from common_testing import TestCaseMixin, load_rgb_image
|
||||
from pytorch3d.io import save_obj
|
||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.renderer.lighting import PointLights
|
||||
from pytorch3d.renderer.mesh import (
|
||||
ClipFrustum,
|
||||
TexturesUV,
|
||||
clip_faces,
|
||||
convert_clipped_rasterization_to_original_faces,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.rasterize_meshes import _RasterizeFaceVerts
|
||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from pytorch3d.renderer.mesh.renderer import MeshRenderer
|
||||
from pytorch3d.renderer.mesh.shader import SoftPhongShader
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
|
||||
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||
# All saved images have prefix DEBUG_
|
||||
DEBUG = False
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
||||
def load_cube_mesh_with_texture(self, device="cpu", with_grad: bool = False):
|
||||
verts = torch.tensor(
|
||||
[
|
||||
[-1, 1, 1],
|
||||
[1, 1, 1],
|
||||
[1, -1, 1],
|
||||
[-1, -1, 1],
|
||||
[-1, 1, -1],
|
||||
[1, 1, -1],
|
||||
[1, -1, -1],
|
||||
[-1, -1, -1],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
requires_grad=with_grad,
|
||||
)
|
||||
|
||||
# all faces correctly wound
|
||||
faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 4],
|
||||
[4, 1, 5],
|
||||
[1, 2, 5],
|
||||
[5, 2, 6],
|
||||
[2, 7, 6],
|
||||
[2, 3, 7],
|
||||
[3, 4, 7],
|
||||
[0, 4, 3],
|
||||
[4, 5, 6],
|
||||
[4, 6, 7],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
|
||||
verts_uvs = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0, 1],
|
||||
[1, 1],
|
||||
[1, 0],
|
||||
[0, 0],
|
||||
[0.204, 0.743],
|
||||
[0.781, 0.743],
|
||||
[0.781, 0.154],
|
||||
[0.204, 0.154],
|
||||
]
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float,
|
||||
)
|
||||
texture_map = load_rgb_image("room.jpg", DATA_DIR).to(device)
|
||||
textures = TexturesUV(
|
||||
maps=[texture_map], faces_uvs=faces.unsqueeze(0), verts_uvs=verts_uvs
|
||||
)
|
||||
mesh = Meshes([verts], [faces], textures=textures)
|
||||
if with_grad:
|
||||
return mesh, verts
|
||||
return mesh
|
||||
|
||||
def test_cube_mesh_render(self):
|
||||
"""
|
||||
End-End test of rendering a cube mesh with texture
|
||||
from decreasing camera distances. The camera starts
|
||||
outside the cube and enters the inside of the cube.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
mesh = self.load_cube_mesh_with_texture(device)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512,
|
||||
blur_radius=1e-8,
|
||||
faces_per_pixel=5,
|
||||
z_clip_value=1e-2,
|
||||
perspective_correct=True,
|
||||
bin_size=0,
|
||||
)
|
||||
|
||||
# Only ambient, no diffuse or specular
|
||||
lights = PointLights(
|
||||
device=device,
|
||||
ambient_color=((1.0, 1.0, 1.0),),
|
||||
diffuse_color=((0.0, 0.0, 0.0),),
|
||||
specular_color=((0.0, 0.0, 0.0),),
|
||||
location=[[0.0, 0.0, -3.0]],
|
||||
)
|
||||
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(raster_settings=raster_settings),
|
||||
shader=SoftPhongShader(device=device, lights=lights),
|
||||
)
|
||||
|
||||
# Render the cube by decreasing the distance from the camera until
|
||||
# the camera enters the cube. Check the output looks correct.
|
||||
images_list = []
|
||||
dists = np.linspace(0.1, 2.5, 20)[::-1]
|
||||
for d in dists:
|
||||
R, T = look_at_view_transform(d, 0, 0)
|
||||
T[0, 1] -= 0.1 # move down in the y axis
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
|
||||
images = renderer(mesh, cameras=cameras)
|
||||
rgb = images[0, ..., :3].cpu().detach()
|
||||
filename = "DEBUG_cube_dist=%.1f.jpg" % d
|
||||
im = (rgb.numpy() * 255).astype(np.uint8)
|
||||
images_list.append(im)
|
||||
|
||||
# Check one of the images where the camera is inside the mesh
|
||||
if d == 0.5:
|
||||
filename = "test_render_mesh_clipped_cam_dist=0.5.jpg"
|
||||
image_ref = load_rgb_image(filename, DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
# Save a gif of the output - this should show
|
||||
# the camera moving inside the cube.
|
||||
if DEBUG:
|
||||
gif_filename = (
|
||||
"room_original.gif"
|
||||
if raster_settings.z_clip_value is None
|
||||
else "room_clipped.gif"
|
||||
)
|
||||
imageio.mimsave(DATA_DIR / gif_filename, images_list, fps=2)
|
||||
save_obj(
|
||||
f=DATA_DIR / "cube.obj",
|
||||
verts=mesh.verts_packed().cpu(),
|
||||
faces=mesh.faces_packed().cpu(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def clip_faces(meshes):
|
||||
verts_packed = meshes.verts_packed()
|
||||
@ -42,6 +189,34 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
return clipped_faces
|
||||
|
||||
def test_grad(self):
|
||||
"""
|
||||
Check that gradient flow is unaffected when the camera is inside the mesh
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
mesh, verts = self.load_cube_mesh_with_texture(device=device, with_grad=True)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512,
|
||||
blur_radius=1e-5,
|
||||
faces_per_pixel=5,
|
||||
z_clip_value=1e-2,
|
||||
perspective_correct=True,
|
||||
bin_size=0,
|
||||
)
|
||||
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(raster_settings=raster_settings),
|
||||
shader=SoftPhongShader(device=device),
|
||||
)
|
||||
dist = 0.4 # Camera is inside the cube
|
||||
R, T = look_at_view_transform(dist, 0, 0)
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
|
||||
images = renderer(mesh, cameras=cameras)
|
||||
images.sum().backward()
|
||||
|
||||
# Check gradients exist
|
||||
self.assertIsNotNone(verts.grad)
|
||||
|
||||
def test_case_1(self):
|
||||
"""
|
||||
Case 1: Single triangle fully in front of the image plane (z=0)
|
||||
@ -350,3 +525,134 @@ class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
||||
# barycentric conversion matrix.
|
||||
bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6])
|
||||
self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)
|
||||
|
||||
def test_convert_clipped_to_unclipped_case_4(self):
|
||||
"""
|
||||
Test with a single case 4 triangle which is clipped into
|
||||
a quadrilateral and subdivided.
|
||||
"""
|
||||
device = "cuda:0"
|
||||
# fmt: off
|
||||
verts = torch.tensor(
|
||||
[
|
||||
[-1.0, 0.0, -1.0], # noqa: E241, E201
|
||||
[ 0.0, 1.0, -1.0], # noqa: E241, E201
|
||||
[ 1.0, 0.0, -1.0], # noqa: E241, E201
|
||||
[ 0.0, -1.0, -1.0], # noqa: E241, E201
|
||||
[-1.0, 0.5, 0.5], # noqa: E241, E201
|
||||
[ 1.0, 1.0, 1.0], # noqa: E241, E201
|
||||
[ 0.0, -1.0, 1.0], # noqa: E241, E201
|
||||
[-1.0, 0.5, -0.5], # noqa: E241, E201
|
||||
[ 1.0, 1.0, -1.0], # noqa: E241, E201
|
||||
[-1.0, 0.0, 1.0], # noqa: E241, E201
|
||||
[ 0.0, 1.0, 1.0], # noqa: E241, E201
|
||||
[ 1.0, 0.0, 1.0], # noqa: E241, E201
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 2], # noqa: E241, E201 Case 2 fully clipped
|
||||
[3, 4, 5], # noqa: E241, E201 Case 4 clipped and subdivided
|
||||
[5, 4, 3], # noqa: E241, E201 Repeat of Case 4
|
||||
[6, 7, 8], # noqa: E241, E201 Case 3 clipped
|
||||
[9, 10, 11], # noqa: E241, E201 Case 1 untouched
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
# fmt: on
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
|
||||
# Clip meshes
|
||||
clipped_faces = self.clip_faces(meshes)
|
||||
|
||||
# 4x faces (from Case 4) + 1 (from Case 3) + 1 (from Case 1)
|
||||
self.assertEqual(clipped_faces.face_verts.shape[0], 6)
|
||||
|
||||
image_size = (10, 10)
|
||||
blur_radius = 0.05
|
||||
faces_per_pixel = 2
|
||||
perspective_correct = True
|
||||
bin_size = 0
|
||||
max_faces_per_bin = 20
|
||||
clip_barycentric_coords = False
|
||||
cull_backfaces = False
|
||||
|
||||
# Rasterize clipped mesh
|
||||
pix_to_face, zbuf, barycentric_coords, dists = _RasterizeFaceVerts.apply(
|
||||
clipped_faces.face_verts,
|
||||
clipped_faces.mesh_to_face_first_idx,
|
||||
clipped_faces.num_faces_per_mesh,
|
||||
clipped_faces.clipped_faces_neighbor_idx,
|
||||
image_size,
|
||||
blur_radius,
|
||||
faces_per_pixel,
|
||||
bin_size,
|
||||
max_faces_per_bin,
|
||||
perspective_correct,
|
||||
clip_barycentric_coords,
|
||||
cull_backfaces,
|
||||
)
|
||||
|
||||
# Convert outputs so they are in terms of the unclipped mesh.
|
||||
outputs = convert_clipped_rasterization_to_original_faces(
|
||||
pix_to_face,
|
||||
barycentric_coords,
|
||||
clipped_faces,
|
||||
)
|
||||
pix_to_face_unclipped, barycentric_coords_unclipped = outputs
|
||||
|
||||
# In the clipped mesh there are more faces than in the unclipped mesh
|
||||
self.assertTrue(pix_to_face.max() > pix_to_face_unclipped.max())
|
||||
# Unclipped pix_to_face indices must be in the limit of the number
|
||||
# of faces in the unclipped mesh.
|
||||
self.assertTrue(pix_to_face_unclipped.max() < faces.shape[0])
|
||||
|
||||
def test_case_4_no_duplicates(self):
|
||||
"""
|
||||
In the case of an simple mesh with one face that is cut by the image
|
||||
plane into a quadrilateral, there shouldn't be duplicates indices of
|
||||
the face in the pix_to_face output of rasterization.
|
||||
"""
|
||||
for (device, bin_size) in [("cpu", 0), ("cuda:0", 0), ("cuda:0", None)]:
|
||||
verts = torch.tensor(
|
||||
[[0.0, -10.0, 1.0], [-1.0, 2.0, -2.0], [1.0, 5.0, -10.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 2],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
k = 3
|
||||
settings = RasterizationSettings(
|
||||
image_size=10,
|
||||
blur_radius=0.05,
|
||||
faces_per_pixel=k,
|
||||
z_clip_value=1e-2,
|
||||
perspective_correct=True,
|
||||
cull_to_frustum=True,
|
||||
bin_size=bin_size,
|
||||
)
|
||||
|
||||
# The camera is positioned so that the image plane cuts
|
||||
# the mesh face into a quadrilateral.
|
||||
R, T = look_at_view_transform(0.2, 0, 0)
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=90)
|
||||
rasterizer = MeshRasterizer(raster_settings=settings, cameras=cameras)
|
||||
fragments = rasterizer(meshes)
|
||||
|
||||
p2f = fragments.pix_to_face.reshape(-1, k)
|
||||
unique_vals, idx_counts = p2f.unique(dim=0, return_counts=True)
|
||||
# There is only one face in this mesh so if it hits a pixel
|
||||
# it can only be at position k = 0
|
||||
# For any pixel, the values [0, 0, 1] for the top K faces cannot be possible
|
||||
double_hit = torch.tensor([0, 0, -1], device=device)
|
||||
check_double_hit = any(torch.allclose(i, double_hit) for i in unique_vals)
|
||||
self.assertFalse(check_double_hit)
|
||||
|
Loading…
x
Reference in New Issue
Block a user