mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Multithread CPU naive mesh rasterization
Summary: Threaded the for loop: ``` for (int yi = 0; yi < H; ++yi) {...} ``` in function `RasterizeMeshesNaiveCpu()`. Chunk size is approx equal. Reviewed By: bottler Differential Revision: D40063604 fbshipit-source-id: 09150269405538119b0f1b029892179501421e68
This commit is contained in:
parent
37bd280d19
commit
6471893f59
@ -10,7 +10,9 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
|
#include <thread>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
#include "ATen/core/TensorAccessor.h"
|
||||||
#include "rasterize_points/rasterization_utils.h"
|
#include "rasterize_points/rasterization_utils.h"
|
||||||
#include "utils/geometry_utils.h"
|
#include "utils/geometry_utils.h"
|
||||||
#include "utils/vec2.h"
|
#include "utils/vec2.h"
|
||||||
@ -117,54 +119,28 @@ struct IsNeighbor {
|
|||||||
int neighbor_idx;
|
int neighbor_idx;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
namespace {
|
||||||
RasterizeMeshesNaiveCpu(
|
void RasterizeMeshesNaiveCpu_worker(
|
||||||
const torch::Tensor& face_verts,
|
const int start_yi,
|
||||||
|
const int end_yi,
|
||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
const torch::Tensor& clipped_faces_neighbor_idx,
|
|
||||||
const std::tuple<int, int> image_size,
|
|
||||||
const float blur_radius,
|
const float blur_radius,
|
||||||
const int faces_per_pixel,
|
|
||||||
const bool perspective_correct,
|
const bool perspective_correct,
|
||||||
const bool clip_barycentric_coords,
|
const bool clip_barycentric_coords,
|
||||||
const bool cull_backfaces) {
|
const bool cull_backfaces,
|
||||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
const int32_t N,
|
||||||
face_verts.size(2) != 3) {
|
const int H,
|
||||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
const int W,
|
||||||
}
|
const int K,
|
||||||
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
|
at::TensorAccessor<float, 3>& face_verts_a,
|
||||||
AT_ERROR(
|
at::TensorAccessor<float, 1>& face_areas_a,
|
||||||
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
|
at::TensorAccessor<float, 2>& face_bboxes_a,
|
||||||
}
|
at::TensorAccessor<int64_t, 1>& neighbor_idx_a,
|
||||||
|
at::TensorAccessor<float, 4>& zbuf_a,
|
||||||
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
|
at::TensorAccessor<int64_t, 4>& face_idxs_a,
|
||||||
const int H = std::get<0>(image_size);
|
at::TensorAccessor<float, 4>& pix_dists_a,
|
||||||
const int W = std::get<1>(image_size);
|
at::TensorAccessor<float, 5>& barycentric_coords_a) {
|
||||||
const int K = faces_per_pixel;
|
|
||||||
|
|
||||||
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
|
|
||||||
auto float_opts = face_verts.options().dtype(torch::kFloat32);
|
|
||||||
|
|
||||||
// Initialize output tensors.
|
|
||||||
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
|
|
||||||
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
|
|
||||||
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
|
|
||||||
torch::Tensor barycentric_coords =
|
|
||||||
torch::full({N, H, W, K, 3}, -1, float_opts);
|
|
||||||
|
|
||||||
auto face_verts_a = face_verts.accessor<float, 3>();
|
|
||||||
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
|
|
||||||
auto zbuf_a = zbuf.accessor<float, 4>();
|
|
||||||
auto pix_dists_a = pix_dists.accessor<float, 4>();
|
|
||||||
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
|
|
||||||
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
|
|
||||||
|
|
||||||
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
|
|
||||||
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
|
|
||||||
auto face_areas = ComputeFaceAreas(face_verts);
|
|
||||||
auto face_areas_a = face_areas.accessor<float, 1>();
|
|
||||||
|
|
||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
// Loop through each mesh in the batch.
|
// Loop through each mesh in the batch.
|
||||||
// Get the start index of the faces in faces_packed and the num faces
|
// Get the start index of the faces in faces_packed and the num faces
|
||||||
@ -174,7 +150,7 @@ RasterizeMeshesNaiveCpu(
|
|||||||
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
|
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
|
||||||
|
|
||||||
// Iterate through the horizontal lines of the image from top to bottom.
|
// Iterate through the horizontal lines of the image from top to bottom.
|
||||||
for (int yi = 0; yi < H; ++yi) {
|
for (int yi = start_yi; yi < end_yi; ++yi) {
|
||||||
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
||||||
const int yidx = H - 1 - yi;
|
const int yidx = H - 1 - yi;
|
||||||
|
|
||||||
@ -324,6 +300,92 @@ RasterizeMeshesNaiveCpu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
||||||
|
RasterizeMeshesNaiveCpu(
|
||||||
|
const torch::Tensor& face_verts,
|
||||||
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
|
const torch::Tensor& clipped_faces_neighbor_idx,
|
||||||
|
const std::tuple<int, int> image_size,
|
||||||
|
const float blur_radius,
|
||||||
|
const int faces_per_pixel,
|
||||||
|
const bool perspective_correct,
|
||||||
|
const bool clip_barycentric_coords,
|
||||||
|
const bool cull_backfaces) {
|
||||||
|
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||||
|
face_verts.size(2) != 3) {
|
||||||
|
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||||
|
}
|
||||||
|
if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
|
||||||
|
AT_ERROR(
|
||||||
|
"num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
|
||||||
|
const int H = std::get<0>(image_size);
|
||||||
|
const int W = std::get<1>(image_size);
|
||||||
|
const int K = faces_per_pixel;
|
||||||
|
|
||||||
|
auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
|
||||||
|
auto float_opts = face_verts.options().dtype(torch::kFloat32);
|
||||||
|
|
||||||
|
// Initialize output tensors.
|
||||||
|
torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
|
||||||
|
torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
|
||||||
|
torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
|
||||||
|
torch::Tensor barycentric_coords =
|
||||||
|
torch::full({N, H, W, K, 3}, -1, float_opts);
|
||||||
|
|
||||||
|
auto face_verts_a = face_verts.accessor<float, 3>();
|
||||||
|
auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
|
||||||
|
auto zbuf_a = zbuf.accessor<float, 4>();
|
||||||
|
auto pix_dists_a = pix_dists.accessor<float, 4>();
|
||||||
|
auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
|
||||||
|
auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
|
||||||
|
|
||||||
|
auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
|
||||||
|
auto face_bboxes_a = face_bboxes.accessor<float, 2>();
|
||||||
|
auto face_areas = ComputeFaceAreas(face_verts);
|
||||||
|
auto face_areas_a = face_areas.accessor<float, 1>();
|
||||||
|
|
||||||
|
const int64_t n_threads = at::get_num_threads();
|
||||||
|
std::vector<std::thread> threads;
|
||||||
|
threads.reserve(n_threads);
|
||||||
|
const int chunk_size = 1 + (H - 1) / n_threads;
|
||||||
|
int start_yi = 0;
|
||||||
|
for (int iThread = 0; iThread < n_threads; ++iThread) {
|
||||||
|
const int64_t end_yi = std::min(start_yi + chunk_size, H);
|
||||||
|
threads.emplace_back(
|
||||||
|
RasterizeMeshesNaiveCpu_worker,
|
||||||
|
start_yi,
|
||||||
|
end_yi,
|
||||||
|
mesh_to_face_first_idx,
|
||||||
|
num_faces_per_mesh,
|
||||||
|
blur_radius,
|
||||||
|
perspective_correct,
|
||||||
|
clip_barycentric_coords,
|
||||||
|
cull_backfaces,
|
||||||
|
N,
|
||||||
|
H,
|
||||||
|
W,
|
||||||
|
K,
|
||||||
|
std::ref(face_verts_a),
|
||||||
|
std::ref(face_areas_a),
|
||||||
|
std::ref(face_bboxes_a),
|
||||||
|
std::ref(neighbor_idx_a),
|
||||||
|
std::ref(zbuf_a),
|
||||||
|
std::ref(face_idxs_a),
|
||||||
|
std::ref(pix_dists_a),
|
||||||
|
std::ref(barycentric_coords_a));
|
||||||
|
start_yi += chunk_size;
|
||||||
|
}
|
||||||
|
for (auto&& thread : threads) {
|
||||||
|
thread.join();
|
||||||
|
}
|
||||||
|
|
||||||
return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
|
return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,13 +4,15 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
from itertools import product
|
from itertools import product
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fvcore.common.benchmark import benchmark
|
from fvcore.common.benchmark import benchmark
|
||||||
from tests.test_rasterize_meshes import TestRasterizeMeshes
|
from tests.test_rasterize_meshes import TestRasterizeMeshes
|
||||||
|
|
||||||
|
BM_RASTERIZE_MESHES_N_THREADS = os.getenv("BM_RASTERIZE_MESHES_N_THREADS", 1)
|
||||||
|
torch.set_num_threads(int(BM_RASTERIZE_MESHES_N_THREADS))
|
||||||
|
|
||||||
# ico levels:
|
# ico levels:
|
||||||
# 0: (12 verts, 20 faces)
|
# 0: (12 verts, 20 faces)
|
||||||
@ -41,7 +43,7 @@ def bm_rasterize_meshes() -> None:
|
|||||||
kwargs_list = []
|
kwargs_list = []
|
||||||
num_meshes = [1]
|
num_meshes = [1]
|
||||||
ico_level = [1]
|
ico_level = [1]
|
||||||
image_size = [64, 128]
|
image_size = [64, 128, 512]
|
||||||
blur = [1e-6]
|
blur = [1e-6]
|
||||||
faces_per_pixel = [3, 50]
|
faces_per_pixel = [3, 50]
|
||||||
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
|
test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
|
||||||
|
@ -35,7 +35,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
|
self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
|
||||||
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
|
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
|
||||||
|
|
||||||
def test_simple_cpu_naive(self):
|
def _test_simple_cpu_naive_instance(self):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
|
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
|
||||||
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
|
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
|
||||||
@ -43,6 +43,16 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
|
self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
|
||||||
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
|
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
|
||||||
|
|
||||||
|
def test_simple_cpu_naive(self):
|
||||||
|
n_threads = torch.get_num_threads()
|
||||||
|
torch.set_num_threads(1) # single threaded
|
||||||
|
self._test_simple_cpu_naive_instance()
|
||||||
|
torch.set_num_threads(4) # even (divisible) number of threads
|
||||||
|
self._test_simple_cpu_naive_instance()
|
||||||
|
torch.set_num_threads(5) # odd (nondivisible) number of threads
|
||||||
|
self._test_simple_cpu_naive_instance()
|
||||||
|
torch.set_num_threads(n_threads)
|
||||||
|
|
||||||
def test_simple_cuda_naive(self):
|
def test_simple_cuda_naive(self):
|
||||||
device = get_random_cuda_device()
|
device = get_random_cuda_device()
|
||||||
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
|
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user