Multithread CPU naive mesh rasterization

Summary:
Parallelized the following for loop across multiple threads:
```
for (int yi = 0; yi < H; ++yi) {...}
```
in function `RasterizeMeshesNaiveCpu()`.
The image rows are split into contiguous chunks of approximately equal size, one chunk per thread.

Reviewed By: bottler

Differential Revision: D40063604

fbshipit-source-id: 09150269405538119b0f1b029892179501421e68
This commit is contained in:
Gavin Peng 2022-10-06 06:42:58 -07:00 committed by Facebook GitHub Bot
parent 37bd280d19
commit 6471893f59
3 changed files with 121 additions and 47 deletions

View File

@ -10,7 +10,9 @@
#include <algorithm> #include <algorithm>
#include <list> #include <list>
#include <queue> #include <queue>
#include <thread>
#include <tuple> #include <tuple>
#include "ATen/core/TensorAccessor.h"
#include "rasterize_points/rasterization_utils.h" #include "rasterize_points/rasterization_utils.h"
#include "utils/geometry_utils.h" #include "utils/geometry_utils.h"
#include "utils/vec2.h" #include "utils/vec2.h"
@ -117,54 +119,28 @@ struct IsNeighbor {
int neighbor_idx; int neighbor_idx;
}; };
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> namespace {
RasterizeMeshesNaiveCpu( void RasterizeMeshesNaiveCpu_worker(
const torch::Tensor& face_verts, const int start_yi,
const int end_yi,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
const torch::Tensor& clipped_faces_neighbor_idx,
const std::tuple<int, int> image_size,
const float blur_radius, const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct, const bool perspective_correct,
const bool clip_barycentric_coords, const bool clip_barycentric_coords,
const bool cull_backfaces) { const bool cull_backfaces,
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || const int32_t N,
face_verts.size(2) != 3) { const int H,
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); const int W,
} const int K,
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) { at::TensorAccessor<float, 3>& face_verts_a,
AT_ERROR( at::TensorAccessor<float, 1>& face_areas_a,
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx"); at::TensorAccessor<float, 2>& face_bboxes_a,
} at::TensorAccessor<int64_t, 1>& neighbor_idx_a,
at::TensorAccessor<float, 4>& zbuf_a,
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size. at::TensorAccessor<int64_t, 4>& face_idxs_a,
const int H = std::get<0>(image_size); at::TensorAccessor<float, 4>& pix_dists_a,
const int W = std::get<1>(image_size); at::TensorAccessor<float, 5>& barycentric_coords_a) {
const int K = faces_per_pixel;
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
auto float_opts = face_verts.options().dtype(torch::kFloat32);
// Initialize output tensors.
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
torch::Tensor barycentric_coords =
torch::full({N, H, W, K, 3}, -1, float_opts);
auto face_verts_a = face_verts.accessor<float, 3>();
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
auto zbuf_a = zbuf.accessor<float, 4>();
auto pix_dists_a = pix_dists.accessor<float, 4>();
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
auto face_areas = ComputeFaceAreas(face_verts);
auto face_areas_a = face_areas.accessor<float, 1>();
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
// Loop through each mesh in the batch. // Loop through each mesh in the batch.
// Get the start index of the faces in faces_packed and the num faces // Get the start index of the faces in faces_packed and the num faces
@ -174,7 +150,7 @@ RasterizeMeshesNaiveCpu(
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>()); (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
// Iterate through the horizontal lines of the image from top to bottom. // Iterate through the horizontal lines of the image from top to bottom.
for (int yi = 0; yi < H; ++yi) { for (int yi = start_yi; yi < end_yi; ++yi) {
// Reverse the order of yi so that +Y is pointing upwards in the image. // Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = H - 1 - yi; const int yidx = H - 1 - yi;
@ -324,6 +300,92 @@ RasterizeMeshesNaiveCpu(
} }
} }
} }
}
} // namespace
// Naive CPU mesh rasterization entry point.
//
// Validates the inputs, allocates the output tensors and splits the image
// rows [0, H) into contiguous chunks that are rasterized by
// RasterizeMeshesNaiveCpu_worker, using up to at::get_num_threads() threads.
//
// Args:
//   face_verts: float tensor; must have shape (num_faces, 3, 3).
//   mesh_to_face_first_idx: per-mesh tensor; its first dimension is used as
//     the batch size N.
//   num_faces_per_mesh: per-mesh tensor; its first dimension must match
//     mesh_to_face_first_idx.
//   clipped_faces_neighbor_idx: read as a 1-D int64 tensor.
//   image_size: (H, W) of the output image.
//   blur_radius, perspective_correct, clip_barycentric_coords, cull_backfaces:
//     forwarded unchanged to the worker.
//   faces_per_pixel: K, the size of the last dimension of the (N, H, W, K)
//     outputs.
//
// Returns:
//   (face_idxs, zbuf, barycentric_coords, pix_dists). face_idxs (int64), zbuf
//   and pix_dists (float32) have shape (N, H, W, K); barycentric_coords
//   (float32) has shape (N, H, W, K, 3). All are initialized to -1 before the
//   workers run.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    const torch::Tensor& clipped_faces_neighbor_idx,
    const std::tuple<int, int> image_size,
    const float blur_radius,
    const int faces_per_pixel,
    const bool perspective_correct,
    const bool clip_barycentric_coords,
    const bool cull_backfaces) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
    AT_ERROR(
        "num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
  }

  const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
  const int H = std::get<0>(image_size);
  const int W = std::get<1>(image_size);
  const int K = faces_per_pixel;

  auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);

  // Initialize output tensors.
  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor barycentric_coords =
      torch::full({N, H, W, K, 3}, -1, float_opts);

  auto face_verts_a = face_verts.accessor<float, 3>();
  auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
  auto zbuf_a = zbuf.accessor<float, 4>();
  auto pix_dists_a = pix_dists.accessor<float, 4>();
  auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
  auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();

  auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
  auto face_bboxes_a = face_bboxes.accessor<float, 2>();
  auto face_areas = ComputeFaceAreas(face_verts);
  auto face_areas_a = face_areas.accessor<float, 1>();

  // Rasterizes the rows [start_yi, end_yi). Everything is captured by
  // reference; this is safe because every thread running this callable is
  // joined before the function returns.
  auto rasterize_rows = [&](const int start_yi, const int end_yi) {
    RasterizeMeshesNaiveCpu_worker(
        start_yi,
        end_yi,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        blur_radius,
        perspective_correct,
        clip_barycentric_coords,
        cull_backfaces,
        N,
        H,
        W,
        K,
        face_verts_a,
        face_areas_a,
        face_bboxes_a,
        neighbor_idx_a,
        zbuf_a,
        face_idxs_a,
        pix_dists_a,
        barycentric_coords_a);
  };

  if (H > 0) {
    // Clamp defensively so the division below can never be by zero.
    const int n_threads =
        static_cast<int>(std::max<int64_t>(1, at::get_num_threads()));
    // Ceiling division: rows per chunk so that n_threads chunks cover H rows.
    const int chunk_size = 1 + (H - 1) / n_threads;

    if (chunk_size >= H) {
      // A single chunk covers the whole image: run it on the calling thread
      // instead of paying for a thread spawn.
      rasterize_rows(0, H);
    } else {
      std::vector<std::thread> threads;
      threads.reserve(n_threads);
      try {
        // Only non-empty row ranges get a thread; when n_threads does not
        // divide H evenly, fewer than n_threads chunks may be needed.
        for (int start_yi = 0; start_yi < H;) {
          const int end_yi = start_yi + std::min(chunk_size, H - start_yi);
          threads.emplace_back(rasterize_rows, start_yi, end_yi);
          start_yi = end_yi;
        }
      } catch (...) {
        // Destroying a joinable std::thread calls std::terminate, so join the
        // threads that were already started before propagating the failure.
        for (auto& thread : threads) {
          if (thread.joinable()) {
            thread.join();
          }
        }
        throw;
      }
      for (auto& thread : threads) {
        thread.join();
      }
    }
  }

  return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
}

View File

@ -4,13 +4,15 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import os
from itertools import product from itertools import product
import torch import torch
from fvcore.common.benchmark import benchmark from fvcore.common.benchmark import benchmark
from tests.test_rasterize_meshes import TestRasterizeMeshes from tests.test_rasterize_meshes import TestRasterizeMeshes
BM_RASTERIZE_MESHES_N_THREADS = os.getenv("BM_RASTERIZE_MESHES_N_THREADS", 1)
torch.set_num_threads(int(BM_RASTERIZE_MESHES_N_THREADS))
# ico levels: # ico levels:
# 0: (12 verts, 20 faces) # 0: (12 verts, 20 faces)
@ -41,7 +43,7 @@ def bm_rasterize_meshes() -> None:
kwargs_list = [] kwargs_list = []
num_meshes = [1] num_meshes = [1]
ico_level = [1] ico_level = [1]
image_size = [64, 128] image_size = [64, 128, 512]
blur = [1e-6] blur = [1e-6]
faces_per_pixel = [3, 50] faces_per_pixel = [3, 50]
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel) test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)

View File

@ -35,7 +35,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1) self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1) self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
def test_simple_cpu_naive(self): def _test_simple_cpu_naive_instance(self):
device = torch.device("cpu") device = torch.device("cpu")
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0) self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0) self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
@ -43,6 +43,16 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._test_perspective_correct(rasterize_meshes, device, bin_size=0) self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
self._test_back_face_culling(rasterize_meshes, device, bin_size=0) self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
def test_simple_cpu_naive(self):
n_threads = torch.get_num_threads()
torch.set_num_threads(1) # single threaded
self._test_simple_cpu_naive_instance()
torch.set_num_threads(4) # even (divisible) number of threads
self._test_simple_cpu_naive_instance()
torch.set_num_threads(5) # odd (nondivisible) number of threads
self._test_simple_cpu_naive_instance()
torch.set_num_threads(n_threads)
def test_simple_cuda_naive(self): def test_simple_cuda_naive(self):
device = get_random_cuda_device() device = get_random_cuda_device()
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0) self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)