mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Multithread CPU naive mesh rasterization
Summary:
Threaded the for loop:
```
for (int yi = 0; yi < H; ++yi) {...}
```
in function `RasterizeMeshesNaiveCpu()`.
Chunk size is approx equal.
Reviewed By: bottler
Differential Revision: D40063604
fbshipit-source-id: 09150269405538119b0f1b029892179501421e68
			
			
This commit is contained in:
		
							parent
							
								
									37bd280d19
								
							
						
					
					
						commit
						6471893f59
					
				@ -10,7 +10,9 @@
 | 
				
			|||||||
#include <algorithm>
 | 
					#include <algorithm>
 | 
				
			||||||
#include <list>
 | 
					#include <list>
 | 
				
			||||||
#include <queue>
 | 
					#include <queue>
 | 
				
			||||||
 | 
					#include <thread>
 | 
				
			||||||
#include <tuple>
 | 
					#include <tuple>
 | 
				
			||||||
 | 
					#include "ATen/core/TensorAccessor.h"
 | 
				
			||||||
#include "rasterize_points/rasterization_utils.h"
 | 
					#include "rasterize_points/rasterization_utils.h"
 | 
				
			||||||
#include "utils/geometry_utils.h"
 | 
					#include "utils/geometry_utils.h"
 | 
				
			||||||
#include "utils/vec2.h"
 | 
					#include "utils/vec2.h"
 | 
				
			||||||
@ -117,54 +119,28 @@ struct IsNeighbor {
 | 
				
			|||||||
  int neighbor_idx;
 | 
					  int neighbor_idx;
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 | 
					namespace {
 | 
				
			||||||
RasterizeMeshesNaiveCpu(
 | 
					void RasterizeMeshesNaiveCpu_worker(
 | 
				
			||||||
    const torch::Tensor& face_verts,
 | 
					    const int start_yi,
 | 
				
			||||||
 | 
					    const int end_yi,
 | 
				
			||||||
    const torch::Tensor& mesh_to_face_first_idx,
 | 
					    const torch::Tensor& mesh_to_face_first_idx,
 | 
				
			||||||
    const torch::Tensor& num_faces_per_mesh,
 | 
					    const torch::Tensor& num_faces_per_mesh,
 | 
				
			||||||
    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
					 | 
				
			||||||
    const std::tuple<int, int> image_size,
 | 
					 | 
				
			||||||
    const float blur_radius,
 | 
					    const float blur_radius,
 | 
				
			||||||
    const int faces_per_pixel,
 | 
					 | 
				
			||||||
    const bool perspective_correct,
 | 
					    const bool perspective_correct,
 | 
				
			||||||
    const bool clip_barycentric_coords,
 | 
					    const bool clip_barycentric_coords,
 | 
				
			||||||
    const bool cull_backfaces) {
 | 
					    const bool cull_backfaces,
 | 
				
			||||||
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
 | 
					    const int32_t N,
 | 
				
			||||||
      face_verts.size(2) != 3) {
 | 
					    const int H,
 | 
				
			||||||
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
 | 
					    const int W,
 | 
				
			||||||
  }
 | 
					    const int K,
 | 
				
			||||||
  if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
 | 
					    at::TensorAccessor<float, 3>& face_verts_a,
 | 
				
			||||||
    AT_ERROR(
 | 
					    at::TensorAccessor<float, 1>& face_areas_a,
 | 
				
			||||||
        "num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
 | 
					    at::TensorAccessor<float, 2>& face_bboxes_a,
 | 
				
			||||||
  }
 | 
					    at::TensorAccessor<int64_t, 1>& neighbor_idx_a,
 | 
				
			||||||
 | 
					    at::TensorAccessor<float, 4>& zbuf_a,
 | 
				
			||||||
  const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
 | 
					    at::TensorAccessor<int64_t, 4>& face_idxs_a,
 | 
				
			||||||
  const int H = std::get<0>(image_size);
 | 
					    at::TensorAccessor<float, 4>& pix_dists_a,
 | 
				
			||||||
  const int W = std::get<1>(image_size);
 | 
					    at::TensorAccessor<float, 5>& barycentric_coords_a) {
 | 
				
			||||||
  const int K = faces_per_pixel;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
 | 
					 | 
				
			||||||
  auto float_opts = face_verts.options().dtype(torch::kFloat32);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  // Initialize output tensors.
 | 
					 | 
				
			||||||
  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
 | 
					 | 
				
			||||||
  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
 | 
					 | 
				
			||||||
  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
 | 
					 | 
				
			||||||
  torch::Tensor barycentric_coords =
 | 
					 | 
				
			||||||
      torch::full({N, H, W, K, 3}, -1, float_opts);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  auto face_verts_a = face_verts.accessor<float, 3>();
 | 
					 | 
				
			||||||
  auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
 | 
					 | 
				
			||||||
  auto zbuf_a = zbuf.accessor<float, 4>();
 | 
					 | 
				
			||||||
  auto pix_dists_a = pix_dists.accessor<float, 4>();
 | 
					 | 
				
			||||||
  auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
 | 
					 | 
				
			||||||
  auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
 | 
					 | 
				
			||||||
  auto face_bboxes_a = face_bboxes.accessor<float, 2>();
 | 
					 | 
				
			||||||
  auto face_areas = ComputeFaceAreas(face_verts);
 | 
					 | 
				
			||||||
  auto face_areas_a = face_areas.accessor<float, 1>();
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  for (int n = 0; n < N; ++n) {
 | 
					  for (int n = 0; n < N; ++n) {
 | 
				
			||||||
    // Loop through each mesh in the batch.
 | 
					    // Loop through each mesh in the batch.
 | 
				
			||||||
    // Get the start index of the faces in faces_packed and the num faces
 | 
					    // Get the start index of the faces in faces_packed and the num faces
 | 
				
			||||||
@ -174,7 +150,7 @@ RasterizeMeshesNaiveCpu(
 | 
				
			|||||||
        (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
 | 
					        (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Iterate through the horizontal lines of the image from top to bottom.
 | 
					    // Iterate through the horizontal lines of the image from top to bottom.
 | 
				
			||||||
    for (int yi = 0; yi < H; ++yi) {
 | 
					    for (int yi = start_yi; yi < end_yi; ++yi) {
 | 
				
			||||||
      // Reverse the order of yi so that +Y is pointing upwards in the image.
 | 
					      // Reverse the order of yi so that +Y is pointing upwards in the image.
 | 
				
			||||||
      const int yidx = H - 1 - yi;
 | 
					      const int yidx = H - 1 - yi;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -324,6 +300,92 @@ RasterizeMeshesNaiveCpu(
 | 
				
			|||||||
      }
 | 
					      }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					} // namespace
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
 | 
				
			||||||
 | 
					RasterizeMeshesNaiveCpu(
 | 
				
			||||||
 | 
					    const torch::Tensor& face_verts,
 | 
				
			||||||
 | 
					    const torch::Tensor& mesh_to_face_first_idx,
 | 
				
			||||||
 | 
					    const torch::Tensor& num_faces_per_mesh,
 | 
				
			||||||
 | 
					    const torch::Tensor& clipped_faces_neighbor_idx,
 | 
				
			||||||
 | 
					    const std::tuple<int, int> image_size,
 | 
				
			||||||
 | 
					    const float blur_radius,
 | 
				
			||||||
 | 
					    const int faces_per_pixel,
 | 
				
			||||||
 | 
					    const bool perspective_correct,
 | 
				
			||||||
 | 
					    const bool clip_barycentric_coords,
 | 
				
			||||||
 | 
					    const bool cull_backfaces) {
 | 
				
			||||||
 | 
					  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
 | 
				
			||||||
 | 
					      face_verts.size(2) != 3) {
 | 
				
			||||||
 | 
					    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
 | 
				
			||||||
 | 
					    AT_ERROR(
 | 
				
			||||||
 | 
					        "num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
 | 
				
			||||||
 | 
					  const int H = std::get<0>(image_size);
 | 
				
			||||||
 | 
					  const int W = std::get<1>(image_size);
 | 
				
			||||||
 | 
					  const int K = faces_per_pixel;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto long_opts = num_faces_per_mesh.options().dtype(torch::kInt64);
 | 
				
			||||||
 | 
					  auto float_opts = face_verts.options().dtype(torch::kFloat32);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  // Initialize output tensors.
 | 
				
			||||||
 | 
					  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
 | 
				
			||||||
 | 
					  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
 | 
				
			||||||
 | 
					  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
 | 
				
			||||||
 | 
					  torch::Tensor barycentric_coords =
 | 
				
			||||||
 | 
					      torch::full({N, H, W, K, 3}, -1, float_opts);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto face_verts_a = face_verts.accessor<float, 3>();
 | 
				
			||||||
 | 
					  auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
 | 
				
			||||||
 | 
					  auto zbuf_a = zbuf.accessor<float, 4>();
 | 
				
			||||||
 | 
					  auto pix_dists_a = pix_dists.accessor<float, 4>();
 | 
				
			||||||
 | 
					  auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
 | 
				
			||||||
 | 
					  auto neighbor_idx_a = clipped_faces_neighbor_idx.accessor<int64_t, 1>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
 | 
				
			||||||
 | 
					  auto face_bboxes_a = face_bboxes.accessor<float, 2>();
 | 
				
			||||||
 | 
					  auto face_areas = ComputeFaceAreas(face_verts);
 | 
				
			||||||
 | 
					  auto face_areas_a = face_areas.accessor<float, 1>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  const int64_t n_threads = at::get_num_threads();
 | 
				
			||||||
 | 
					  std::vector<std::thread> threads;
 | 
				
			||||||
 | 
					  threads.reserve(n_threads);
 | 
				
			||||||
 | 
					  const int chunk_size = 1 + (H - 1) / n_threads;
 | 
				
			||||||
 | 
					  int start_yi = 0;
 | 
				
			||||||
 | 
					  for (int iThread = 0; iThread < n_threads; ++iThread) {
 | 
				
			||||||
 | 
					    const int64_t end_yi = std::min(start_yi + chunk_size, H);
 | 
				
			||||||
 | 
					    threads.emplace_back(
 | 
				
			||||||
 | 
					        RasterizeMeshesNaiveCpu_worker,
 | 
				
			||||||
 | 
					        start_yi,
 | 
				
			||||||
 | 
					        end_yi,
 | 
				
			||||||
 | 
					        mesh_to_face_first_idx,
 | 
				
			||||||
 | 
					        num_faces_per_mesh,
 | 
				
			||||||
 | 
					        blur_radius,
 | 
				
			||||||
 | 
					        perspective_correct,
 | 
				
			||||||
 | 
					        clip_barycentric_coords,
 | 
				
			||||||
 | 
					        cull_backfaces,
 | 
				
			||||||
 | 
					        N,
 | 
				
			||||||
 | 
					        H,
 | 
				
			||||||
 | 
					        W,
 | 
				
			||||||
 | 
					        K,
 | 
				
			||||||
 | 
					        std::ref(face_verts_a),
 | 
				
			||||||
 | 
					        std::ref(face_areas_a),
 | 
				
			||||||
 | 
					        std::ref(face_bboxes_a),
 | 
				
			||||||
 | 
					        std::ref(neighbor_idx_a),
 | 
				
			||||||
 | 
					        std::ref(zbuf_a),
 | 
				
			||||||
 | 
					        std::ref(face_idxs_a),
 | 
				
			||||||
 | 
					        std::ref(pix_dists_a),
 | 
				
			||||||
 | 
					        std::ref(barycentric_coords_a));
 | 
				
			||||||
 | 
					    start_yi += chunk_size;
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  for (auto&& thread : threads) {
 | 
				
			||||||
 | 
					    thread.join();
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
 | 
					  return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -4,13 +4,15 @@
 | 
				
			|||||||
# This source code is licensed under the BSD-style license found in the
 | 
					# This source code is licensed under the BSD-style license found in the
 | 
				
			||||||
# LICENSE file in the root directory of this source tree.
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import os
 | 
				
			||||||
from itertools import product
 | 
					from itertools import product
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from fvcore.common.benchmark import benchmark
 | 
					from fvcore.common.benchmark import benchmark
 | 
				
			||||||
from tests.test_rasterize_meshes import TestRasterizeMeshes
 | 
					from tests.test_rasterize_meshes import TestRasterizeMeshes
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					BM_RASTERIZE_MESHES_N_THREADS = os.getenv("BM_RASTERIZE_MESHES_N_THREADS", 1)
 | 
				
			||||||
 | 
					torch.set_num_threads(int(BM_RASTERIZE_MESHES_N_THREADS))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# ico levels:
 | 
					# ico levels:
 | 
				
			||||||
# 0: (12 verts, 20 faces)
 | 
					# 0: (12 verts, 20 faces)
 | 
				
			||||||
@ -41,7 +43,7 @@ def bm_rasterize_meshes() -> None:
 | 
				
			|||||||
    kwargs_list = []
 | 
					    kwargs_list = []
 | 
				
			||||||
    num_meshes = [1]
 | 
					    num_meshes = [1]
 | 
				
			||||||
    ico_level = [1]
 | 
					    ico_level = [1]
 | 
				
			||||||
    image_size = [64, 128]
 | 
					    image_size = [64, 128, 512]
 | 
				
			||||||
    blur = [1e-6]
 | 
					    blur = [1e-6]
 | 
				
			||||||
    faces_per_pixel = [3, 50]
 | 
					    faces_per_pixel = [3, 50]
 | 
				
			||||||
    test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
 | 
					    test_cases = product(num_meshes, ico_level, image_size, blur, faces_per_pixel)
 | 
				
			||||||
 | 
				
			|||||||
@ -35,7 +35,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
        self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
 | 
					        self._test_barycentric_clipping(rasterize_meshes_python, device, bin_size=-1)
 | 
				
			||||||
        self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
 | 
					        self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_simple_cpu_naive(self):
 | 
					    def _test_simple_cpu_naive_instance(self):
 | 
				
			||||||
        device = torch.device("cpu")
 | 
					        device = torch.device("cpu")
 | 
				
			||||||
        self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
 | 
					        self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
 | 
				
			||||||
        self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
 | 
					        self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
 | 
				
			||||||
@ -43,6 +43,16 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
        self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
 | 
					        self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
 | 
				
			||||||
        self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
 | 
					        self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_simple_cpu_naive(self):
 | 
				
			||||||
 | 
					        n_threads = torch.get_num_threads()
 | 
				
			||||||
 | 
					        torch.set_num_threads(1)  # single threaded
 | 
				
			||||||
 | 
					        self._test_simple_cpu_naive_instance()
 | 
				
			||||||
 | 
					        torch.set_num_threads(4)  # even (divisible) number of threads
 | 
				
			||||||
 | 
					        self._test_simple_cpu_naive_instance()
 | 
				
			||||||
 | 
					        torch.set_num_threads(5)  # odd (nondivisible) number of threads
 | 
				
			||||||
 | 
					        self._test_simple_cpu_naive_instance()
 | 
				
			||||||
 | 
					        torch.set_num_threads(n_threads)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_simple_cuda_naive(self):
 | 
					    def test_simple_cuda_naive(self):
 | 
				
			||||||
        device = get_random_cuda_device()
 | 
					        device = get_random_cuda_device()
 | 
				
			||||||
        self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
 | 
					        self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user