Mirror of https://github.com/facebookresearch/pytorch3d.git
coarse rasterization bug fix

Summary: Fix a bug which resulted in rendering artifacts if the image size was not a multiple of 16.

Fix: Revert coarse rasterization to the original implementation and only update fine rasterization to reverse the ordering of the Y and X axes. This is much simpler than the previous approach!

Additional changes:
- Updated mesh rendering end-to-end tests to check outputs from both naive and coarse-to-fine rasterization.
- Added pointcloud rendering end-to-end tests.

Reviewed By: gkioxari

Differential Revision: D21102725

fbshipit-source-id: 2e7e1b013dd6dd12b3a00b79eb8167deddb2e89a
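To make the failure mode concrete, here is a rough sketch (not code from this diff) of how the coarse bins relate to the image. The pix_to_ndc formula is assumed to match the PixToNdc helper used by the kernels, and the 500/32 sizes are only illustrative (the real minimum bin size is 16).

    # Illustrative sketch only: why an image size that is not a multiple of the
    # bin size broke the previous (reversed-index) coarse binning.
    def pix_to_ndc(i, S):
        # Assumed to mirror the kernels' PixToNdc: NDC location of the center of pixel i.
        return -1.0 + (2.0 * i + 1.0) / S

    image_size, bin_size = 500, 32                 # 500 is not a multiple of 32
    num_bins = 1 + (image_size - 1) // bin_size    # divide round up -> 16 bins

    def bin_pixels_old(b):
        # Previous coarse binning: flip the bin index, then take its pixel extent.
        idx = num_bins - b
        return (idx - 1) * bin_size, idx * bin_size - 1

    def bin_pixels_new(b):
        # Reverted coarse binning: take the extent directly from the bin index.
        return b * bin_size, (b + 1) * bin_size - 1

    print(bin_pixels_old(0))    # (480, 511): the partial, overhanging bin lands at index 0
    print(bin_pixels_new(15))   # (480, 511): the overhang now sits in the last bin

Because num_bins * bin_size can exceed the image size, the flipped indexing shifted every bin's pixel extent by that overhang, so the faces or points collected for a bin no longer lined up with the pixels the fine kernel shaded for that bin. After the revert, the Y/X flip happens only once, in the fine kernel, when the per-pixel output is written (yidx = H - 1 - yi, xidx = W - 1 - xi), as the diff below shows.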
parent 1e4749602d
commit 9ef1ee8455
[Binary image file changed: 62 KiB before, 64 KiB after (binary file not shown). One additional file diff suppressed because one or more lines are too long.]
@@ -556,18 +556,16 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
         // PixToNdc gives the location of the center of each pixel, so we
         // need to add/subtract a half pixel to get the true extent of the bin.
-        // Reverse ordering of Y axis so that +Y is upwards in the image.
-        const int yidx = num_bins - by;
-        const float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
-        const float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
-
+        const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
+        const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
         const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);

         for (int bx = 0; bx < num_bins; ++bx) {
           // X coordinate of the left and right of the bin.
-          // Reverse ordering of x axis so that +X is left.
-          const int xidx = num_bins - bx;
-          const float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
-          const float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
+          const float bin_x_max =
+              PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
+          const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;

           const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
           if (y_overlap && x_overlap) {

@@ -629,6 +627,7 @@ torch::Tensor RasterizeMeshesCoarseCuda(
   const int N = num_faces_per_mesh.size(0);
+  const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
   const int M = max_faces_per_bin;

   if (num_bins >= 22) {
     std::stringstream ss;
     ss << "Got " << num_bins << "; that's too many!";

@@ -702,13 +701,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
     if (yi >= H || xi >= W)
       continue;

-    // Reverse ordering of the X and Y axis so that
-    // in the image +Y is pointing up and +X is pointing left.
-    const int yidx = H - 1 - yi;
-    const int xidx = W - 1 - xi;
-
-    const float xf = PixToNdc(xidx, W);
-    const float yf = PixToNdc(yidx, H);
+    const float xf = PixToNdc(xi, W);
+    const float yf = PixToNdc(yi, H);
     const float2 pxy = make_float2(xf, yf);

     // This part looks like the naive rasterization kernel, except we use

@@ -743,7 +737,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
     // output for the current pixel.
     // TODO: make sorting an option as only top k is needed, not sorted values.
     BubbleSort(q, q_size);
-    const int pix_idx = n * H * W * K + yi * H * K + xi * K;
+
+    // Reverse ordering of the X and Y axis so that
+    // in the image +Y is pointing up and +X is pointing left.
+    const int yidx = H - 1 - yi;
+    const int xidx = W - 1 - xi;
+
+    const int pix_idx = n * H * W * K + yidx * H * K + xidx * K;
     for (int k = 0; k < q_size; k++) {
       face_idxs[pix_idx + k] = q[k].idx;
       zbuf[pix_idx + k] = q[k].z;
@@ -430,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
     const int face_stop_idx =
         (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());

-    float bin_y_max = 1.0f;
-    float bin_y_min = bin_y_max - bin_width;
+    float bin_y_min = -1.0f;
+    float bin_y_max = bin_y_min + bin_width;

     // Iterate through the horizontal bins from top to bottom.
     for (int by = 0; by < BH; ++by) {
-      float bin_x_max = 1.0f;
-      float bin_x_min = bin_x_max - bin_width;
+      float bin_x_min = -1.0f;
+      float bin_x_max = bin_x_min + bin_width;

       // Iterate through bins on this horizontal line, left to right.
       for (int bx = 0; bx < BW; ++bx) {

@@ -473,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
           }
         }

-        // Shift the bin to the left for the next loop iteration.
-        bin_x_max = bin_x_min;
-        bin_x_min = bin_x_min - bin_width;
+        // Shift the bin to the right for the next loop iteration
+        bin_x_min = bin_x_max;
+        bin_x_max = bin_x_min + bin_width;
       }
-      // Shift the bin down for the next loop iteration.
-      bin_y_max = bin_y_min;
-      bin_y_min = bin_y_min - bin_width;
+      // Shift the bin down for the next loop iteration
+      bin_y_min = bin_y_max;
+      bin_y_max = bin_y_min + bin_width;
     }
   }
   return bin_faces;
@@ -95,7 +95,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
     const int n = i / (S * S); // Batch index
     const int pix_idx = i % (S * S);

-    // Reverse ordering of X and Y axes.
+    // Reverse ordering of the X and Y axis as the camera coordinates
+    // assume that +Y is pointing up and +X is pointing left.
    const int yi = S - 1 - pix_idx / S;
    const int xi = S - 1 - pix_idx % S;

@@ -260,23 +261,20 @@ __global__ void RasterizePointsCoarseCudaKernel(
         // Get y extent for the bin. PixToNdc gives us the location of
         // the center of each pixel, so we need to add/subtract a half
         // pixel to get the true extent of the bin.
-        // Reverse ordering of Y axis so that +Y is upwards in the image.
-        const int yidx = num_bins - by;
-        const float bin_y_max = PixToNdc(yidx * bin_size - 1, S) + half_pix;
-        const float bin_y_min = PixToNdc((yidx - 1) * bin_size, S) - half_pix;
+        const float by0 = PixToNdc(by * bin_size, S) - half_pix;
+        const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix;
+        const bool y_overlap = (py0 <= by1) && (by0 <= py1);

-        const bool y_overlap = (py0 <= bin_y_max) && (bin_y_min <= py1);
         if (!y_overlap) {
           continue;
         }
         for (int bx = 0; bx < num_bins; ++bx) {
           // Get x extent for the bin; again we need to adjust the
           // output of PixToNdc by half a pixel.
-          // Reverse ordering of x axis so that +X is left.
-          const int xidx = num_bins - bx;
-          const float bin_x_max = PixToNdc(xidx * bin_size - 1, S) + half_pix;
-          const float bin_x_min = PixToNdc((xidx - 1) * bin_size, S) - half_pix;
-          const bool x_overlap = (px0 <= bin_x_max) && (bin_x_min <= px1);
+          const float bx0 = PixToNdc(bx * bin_size, S) - half_pix;
+          const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix;
+          const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);

           if (x_overlap) {
             binmask.set(by, bx, p);
           }

@@ -330,6 +328,7 @@ torch::Tensor RasterizePointsCoarseCuda(
   const int N = num_points_per_cloud.size(0);
+  const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
   const int M = max_points_per_bin;

   if (points.ndimension() != 2 || points.size(1) != 3) {
     AT_ERROR("points must have dimensions (num_points, 3)");
   }

@@ -346,6 +345,7 @@ torch::Tensor RasterizePointsCoarseCuda(
   const size_t shared_size = num_bins * num_bins * chunk_size / 8;
   const size_t blocks = 64;
   const size_t threads = 512;
+
   RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
       points.contiguous().data_ptr<float>(),
       cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),

@@ -372,7 +372,7 @@ __global__ void RasterizePointsFineCudaKernel(
     const float radius,
     const int bin_size,
     const int N,
-    const int B,
+    const int B, // num_bins
     const int M,
     const int S,
     const int K,

@@ -397,19 +397,15 @@ __global__ void RasterizePointsFineCudaKernel(
     i %= B * bin_size * bin_size;
     const int bx = i / (bin_size * bin_size);
     i %= bin_size * bin_size;

     const int yi = i / bin_size + by * bin_size;
     const int xi = i % bin_size + bx * bin_size;

     if (yi >= S || xi >= S)
       continue;

-    // Reverse ordering of the X and Y axis so that
-    // in the image +Y is pointing up and +X is pointing left.
-    const int yidx = S - 1 - yi;
-    const int xidx = S - 1 - xi;
-
-    const float xf = PixToNdc(xidx, S);
-    const float yf = PixToNdc(yidx, S);
+    const float xf = PixToNdc(xi, S);
+    const float yf = PixToNdc(yi, S);

     // This part looks like the naive rasterization kernel, except we use
     // bin_points to only look at a subset of points already known to fall

@@ -431,7 +427,13 @@ __global__ void RasterizePointsFineCudaKernel(
     // Now we've looked at all the points for this bin, so we can write
     // output for the current pixel.
     BubbleSort(q, q_size);
-    const int pix_idx = n * S * S * K + yi * S * K + xi * K;
+
+    // Reverse ordering of the X and Y axis as the camera coordinates
+    // assume that +Y is pointing up and +X is pointing left.
+    const int yidx = S - 1 - yi;
+    const int xidx = S - 1 - xi;
+
+    const int pix_idx = n * S * S * K + yidx * S * K + xidx * K;
     for (int k = 0; k < q_size; ++k) {
       point_idxs[pix_idx + k] = q[k].idx;
       zbuf[pix_idx + k] = q[k].z;

@@ -448,7 +450,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
     const int bin_size,
     const int points_per_pixel) {
   const int N = bin_points.size(0);
-  const int B = bin_points.size(1);
+  const int B = bin_points.size(1); // num_bins
   const int M = bin_points.size(3);
   const int S = image_size;
   const int K = points_per_pixel;
@@ -125,13 +125,13 @@ torch::Tensor RasterizePointsCoarseCpu(
     const int point_stop_idx =
         (point_start_idx + num_points_per_cloud[n].item().to<int32_t>());

-    float bin_y_max = 1.0f;
-    float bin_y_min = bin_y_max - bin_width;
+    float bin_y_min = -1.0f;
+    float bin_y_max = bin_y_min + bin_width;

     // Iterate through the horizontal bins from top to bottom.
     for (int by = 0; by < B; by++) {
-      float bin_x_max = 1.0f;
-      float bin_x_min = bin_x_max - bin_width;
+      float bin_x_min = -1.0f;
+      float bin_x_max = bin_x_min + bin_width;

       // Iterate through bins on this horizontal line, left to right.
       for (int bx = 0; bx < B; bx++) {

@@ -166,13 +166,13 @@ torch::Tensor RasterizePointsCoarseCpu(
         // Record the number of points found in this bin
         points_per_bin_a[n][by][bx] = points_hit;

-        // Shift the bin to the left for the next loop iteration.
-        bin_x_max = bin_x_min;
-        bin_x_min = bin_x_min - bin_width;
+        // Shift the bin to the right for the next loop iteration
+        bin_x_min = bin_x_max;
+        bin_x_max = bin_x_min + bin_width;
       }
-      // Shift the bin down for the next loop iteration.
-      bin_y_max = bin_y_min;
-      bin_y_min = bin_y_min - bin_width;
+      // Shift the bin down for the next loop iteration
+      bin_y_min = bin_y_max;
+      bin_y_max = bin_y_min + bin_width;
     }
   }
   return bin_points;
@@ -9,7 +9,7 @@ from pytorch3d import _C


 # TODO make the epsilon user configurable
-kEpsilon = 1e-30
+kEpsilon = 1e-8


 def rasterize_meshes(
@@ -19,12 +19,28 @@ class PointFragments(NamedTuple):


 # Class to store the point rasterization params with defaults
-class PointsRasterizationSettings(NamedTuple):
-    image_size: int = 256
-    radius: float = 0.01
-    points_per_pixel: int = 8
-    bin_size: Optional[int] = None
-    max_points_per_bin: Optional[int] = None
+class PointsRasterizationSettings:
+    __slots__ = [
+        "image_size",
+        "radius",
+        "points_per_pixel",
+        "bin_size",
+        "max_points_per_bin",
+    ]
+
+    def __init__(
+        self,
+        image_size: int = 256,
+        radius: float = 0.01,
+        points_per_pixel: int = 8,
+        bin_size: Optional[int] = None,
+        max_points_per_bin: Optional[int] = None,
+    ):
+        self.image_size = image_size
+        self.radius = radius
+        self.points_per_pixel = points_per_pixel
+        self.bin_size = bin_size
+        self.max_points_per_bin = max_points_per_bin


 class PointsRasterizer(nn.Module):
@@ -1,10 +1,20 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 import unittest
+from pathlib import Path
+from typing import Callable, Optional, Union

 import numpy as np
 import torch
+from PIL import Image


+def load_rgb_image(filename: str, data_dir: Union[str, Path]):
+    filepath = data_dir / filename
+    with Image.open(filepath) as raw_image:
+        image = torch.from_numpy(np.array(raw_image) / 255.0)
+    image = image.to(dtype=torch.float32)
+    return image[..., :3]
+
+
 TensorOrArray = Union[torch.Tensor, np.ndarray]

BIN  tests/data/test_bridge_pointcloud.png  (new file, 74 KiB; binary file not shown)
BIN  tests/data/test_simple_pointcloud_sphere.png  (new file, 2.1 KiB; binary file not shown)
@@ -896,10 +896,10 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
             torch.ones((1, 2, 2, max_faces_per_bin), dtype=torch.int32, device=device)
             * -1
         )
-        bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
-        bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([1, 2])
-        bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([0, 1])
         bin_faces_expected[0, 1, 1, 0] = torch.tensor([1])
+        bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([1, 2])
+        bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([0, 1])
+        bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])

         # +Y up, +X left, +Z in
         bin_faces = _C._rasterize_meshes_coarse(

@@ -911,7 +911,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
             bin_size,
             max_faces_per_bin,
         )
-        # Flip x and y axis of output before comparing to expected
+
         bin_faces_same = (bin_faces.squeeze() == bin_faces_expected).all()
         self.assertTrue(bin_faces_same.item() == 1)

@@ -434,23 +434,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):

    def _test_coarse_rasterize(self, device):
        #
-        #  Note that +Y is up and +X is left in the diagram below.
        #
-        #  (4)              |2
+        #           |2                  (4)
        #           |
        #           |
        #           |
        #           |1
        #           |
-        #             (1)   |
-        #                   | (2)
-        # ____________(0)__(5)___________________
-        # 2        1        |          -1      -2
+        #           |    (1)
+        #        (2)|
+        # _________(5)___(0)_______________
+        # -1        |           1         2
        #           |
-        #       (3)         |
+        #           |            (3)
        #           |
        #           |-1
-        #                   |
        #
        # Locations of the points are shown by o. The screen bounding box
        # is between [-1, 1] in both the x and y directions.

@@ -486,9 +484,9 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        # fit in one chunk. This will the the case for this small example, but
        # to properly exercise coordianted writes among multiple chunks we need
        # to use a bigger test case.
-        bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3])
-        bin_points_expected[0, 0, 1, 0] = torch.tensor([2])
-        bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1])
+        bin_points_expected[0, 0, 1, :2] = torch.tensor([0, 3])
+        bin_points_expected[0, 1, 0, 0] = torch.tensor([2])
+        bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])

        pointclouds = Pointclouds(points=[points])
        args = (

@@ -502,4 +500,5 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
        )
        bin_points = _C._rasterize_points_coarse(*args)
        bin_points_same = (bin_points == bin_points_expected).all()
+
        self.assertTrue(bin_points_same.item() == 1)
@@ -9,6 +9,7 @@ from pathlib import Path

 import numpy as np
 import torch
+from common_testing import TestCaseMixin, load_rgb_image
 from PIL import Image
 from pytorch3d.io import load_objs_as_meshes
 from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform

@@ -35,15 +36,7 @@ DEBUG = False
 DATA_DIR = Path(__file__).resolve().parent / "data"


-def load_rgb_image(filename, data_dir=DATA_DIR):
-    filepath = data_dir / filename
-    with Image.open(filepath) as raw_image:
-        image = torch.from_numpy(np.array(raw_image) / 255.0)
-    image = image.to(dtype=torch.float32)
-    return image[..., :3]
-
-
-class TestRenderingMeshes(unittest.TestCase):
+class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
     def test_simple_sphere(self, elevated_camera=False):
         """
         Test output of phong and gouraud shading matches a reference image using

@@ -81,7 +74,7 @@ class TestRenderingMeshes(unittest.TestCase):
         lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

         raster_settings = RasterizationSettings(
-            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
         )
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)

@@ -96,14 +89,14 @@ class TestRenderingMeshes(unittest.TestCase):
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_mesh)
             filename = "simple_sphere_light_%s%s.png" % (name, postfix)
-            image_ref = load_rgb_image("test_%s" % filename)
+            image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
             rgb = images[0, ..., :3].squeeze().cpu()
             if DEBUG:
-                filename = "DEBUG_" % filename
+                filename = "DEBUG_%s" % filename
                 Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                     DATA_DIR / filename
                 )
-            self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
+            self.assertClose(rgb, image_ref, atol=0.05)

         ########################################################
         # Move the light to the +z axis in world space so it is

@@ -124,8 +117,10 @@ class TestRenderingMeshes(unittest.TestCase):
             )

         # Load reference image
-        image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
-        self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
+        image_ref_phong_dark = load_rgb_image(
+            "test_simple_sphere_dark%s.png" % postfix, DATA_DIR
+        )
+        self.assertClose(rgb, image_ref_phong_dark, atol=0.05)

     def test_simple_sphere_elevated_camera(self):
         """

@@ -160,7 +155,7 @@ class TestRenderingMeshes(unittest.TestCase):
         R, T = look_at_view_transform(dist, elev, azim)
         cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
-            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
         )

         # Init shader settings

@@ -179,10 +174,12 @@ class TestRenderingMeshes(unittest.TestCase):
             shader = shader_init(lights=lights, cameras=cameras, materials=materials)
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_meshes)
-            image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
+            image_ref = load_rgb_image(
+                "test_simple_sphere_light_%s.png" % name, DATA_DIR
+            )
             for i in range(batch_size):
                 rgb = images[i, ..., :3].squeeze().cpu()
-                self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
+                self.assertClose(rgb, image_ref, atol=0.05)

     def test_silhouette_with_grad(self):
         """

@@ -200,7 +197,6 @@ class TestRenderingMeshes(unittest.TestCase):
             image_size=512,
             blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
             faces_per_pixel=80,
-            bin_size=0,
         )

         # Init rasterizer settings

@@ -222,7 +218,7 @@ class TestRenderingMeshes(unittest.TestCase):
         with Image.open(image_ref_filename) as raw_image_ref:
             image_ref = torch.from_numpy(np.array(raw_image_ref))
         image_ref = image_ref.to(dtype=torch.float32) / 255.0
-        self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))
+        self.assertClose(alpha, image_ref, atol=0.055)

         # Check grad exist
         verts.requires_grad = True

@@ -237,8 +233,8 @@ class TestRenderingMeshes(unittest.TestCase):
         The pupils in the eyes of the cow should always be looking to the left.
         """
         device = torch.device("cuda:0")
-        DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
-        obj_filename = DATA_DIR / "cow_mesh/cow.obj"
+        obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
+        obj_filename = obj_dir / "cow_mesh/cow.obj"

         # Load mesh + texture
         mesh = load_objs_as_meshes([obj_filename], device=device)

@@ -247,7 +243,7 @@ class TestRenderingMeshes(unittest.TestCase):
         R, T = look_at_view_transform(2.7, 0, 0)
         cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
-            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
         )

         # Init shader settings

@@ -265,11 +261,15 @@ class TestRenderingMeshes(unittest.TestCase):
                 lights=lights, cameras=cameras, materials=materials
             ),
         )
-        images = renderer(mesh)
-        rgb = images[0, ..., :3].squeeze().cpu()

         # Load reference image
-        image_ref = load_rgb_image("test_texture_map_back.png")
+        image_ref = load_rgb_image("test_texture_map_back.png", DATA_DIR)
+
+        for bin_size in [0, None]:
+            # Check both naive and coarse to fine produce the same output.
+            renderer.rasterizer.raster_settings.bin_size = bin_size
+            images = renderer(mesh)
+            rgb = images[0, ..., :3].squeeze().cpu()

             if DEBUG:
                 Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(

@@ -299,17 +299,28 @@ class TestRenderingMeshes(unittest.TestCase):

         # Move light to the front of the cow in world space
         lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
-        images = renderer(mesh, cameras=cameras, lights=lights)
-        rgb = images[0, ..., :3].squeeze().cpu()

         # Load reference image
-        image_ref = load_rgb_image("test_texture_map_front.png")
+        image_ref = load_rgb_image("test_texture_map_front.png", DATA_DIR)
+
+        for bin_size in [0, None]:
+            # Check both naive and coarse to fine produce the same output.
+            renderer.rasterizer.raster_settings.bin_size = bin_size
+
+            images = renderer(mesh, cameras=cameras, lights=lights)
+            rgb = images[0, ..., :3].squeeze().cpu()
+
+            if DEBUG:
+                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / "DEBUG_texture_map_front.png"
+                )
+
+            # NOTE some pixels can be flaky and will not lead to
+            # `cond1` being true. Add `cond2` and check `cond1 or cond2`
+            cond1 = torch.allclose(rgb, image_ref, atol=0.05)
+            cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
+            self.assertTrue(cond1 or cond2)

         #################################
         # Add blurring to rasterization
         #################################

@@ -320,9 +331,15 @@ class TestRenderingMeshes(unittest.TestCase):
             image_size=512,
             blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
             faces_per_pixel=100,
-            bin_size=0,
         )

+        # Load reference image
+        image_ref = load_rgb_image("test_blurry_textured_rendering.png", DATA_DIR)
+
+        for bin_size in [0, None]:
+            # Check both naive and coarse to fine produce the same output.
+            renderer.rasterizer.raster_settings.bin_size = bin_size
+
             images = renderer(
                 mesh.clone(),
                 cameras=cameras,

@@ -331,12 +348,9 @@ class TestRenderingMeshes(unittest.TestCase):
             )
             rgb = images[0, ..., :3].squeeze().cpu()

-        # Load reference image
-        image_ref = load_rgb_image("test_blurry_textured_rendering.png")
-
             if DEBUG:
                 Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                     DATA_DIR / "DEBUG_blurry_textured_rendering.png"
                 )

-        self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
+            self.assertClose(rgb, image_ref, atol=0.05)

173  tests/test_render_points.py  (new file)
@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.


"""
Sanity checks for output images from the pointcloud renderer.
"""
import unittest
import warnings
from os import path
from pathlib import Path

import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.renderer.cameras import (
    OpenGLOrthographicCameras,
    OpenGLPerspectiveCameras,
    look_at_view_transform,
)
from pytorch3d.renderer.points import (
    AlphaCompositor,
    NormWeightedCompositor,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere


# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"


class TestRenderPoints(TestCaseMixin, unittest.TestCase):
    def test_simple_sphere(self):
        device = torch.device("cuda:0")
        sphere_mesh = ico_sphere(1, device)
        verts_padded = sphere_mesh.verts_padded()
        # Shift vertices to check coordinate frames are correct.
        verts_padded[..., 1] += 0.2
        verts_padded[..., 0] += 0.2
        pointclouds = Pointclouds(
            points=verts_padded, features=torch.ones_like(verts_padded)
        )
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = PointsRasterizationSettings(
            image_size=256, radius=5e-2, points_per_pixel=1
        )
        rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
        compositor = NormWeightedCompositor()
        renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)

        # Load reference image
        filename = "simple_pointcloud_sphere.png"
        image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size
            images = renderer(pointclouds)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertClose(rgb, image_ref)

    def test_pointcloud_with_features(self):
        device = torch.device("cuda:0")
        file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
        pointcloud_filename = file_dir / "PittsburghBridge/pointcloud.npz"

        # Note, this file is too large to check in to the repo.
        # Download the file to run the test locally.
        if not path.exists(pointcloud_filename):
            url = "https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz"
            msg = (
                "pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
                % (url, pointcloud_filename)
            )
            warnings.warn(msg)
            return True

        # Load point cloud
        pointcloud = np.load(pointcloud_filename)
        verts = torch.Tensor(pointcloud["verts"]).to(device)
        rgb_feats = torch.Tensor(pointcloud["rgb"]).to(device)

        verts.requires_grad = True
        rgb_feats.requires_grad = True
        point_cloud = Pointclouds(points=[verts], features=[rgb_feats])

        R, T = look_at_view_transform(20, 10, 0)
        cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)

        raster_settings = PointsRasterizationSettings(
            # Set image_size so it is not a multiple of 16 (min bin_size)
            # in order to confirm that there are no errors in coarse rasterization.
            image_size=500,
            radius=0.003,
            points_per_pixel=10,
        )

        renderer = PointsRenderer(
            rasterizer=PointsRasterizer(
                cameras=cameras, raster_settings=raster_settings
            ),
            compositor=AlphaCompositor(),
        )

        images = renderer(point_cloud)

        # Load reference image
        filename = "bridge_pointcloud.png"
        image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size
            images = renderer(point_cloud)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.detach().numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertClose(rgb, image_ref, atol=0.015)

        # Check grad exists.
        grad_images = torch.randn_like(images)
        images.backward(grad_images)
        self.assertIsNotNone(verts.grad)
        self.assertIsNotNone(rgb_feats.grad)

    def test_simple_sphere_batched(self):
        device = torch.device("cuda:0")
        sphere_mesh = ico_sphere(1, device)
        verts_padded = sphere_mesh.verts_padded()
        verts_padded[..., 1] += 0.2
        verts_padded[..., 0] += 0.2
        pointclouds = Pointclouds(
            points=verts_padded, features=torch.ones_like(verts_padded)
        )
        batch_size = 20
        pointclouds = pointclouds.extend(batch_size)
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = PointsRasterizationSettings(
            image_size=256, radius=5e-2, points_per_pixel=1
        )
        rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
        compositor = NormWeightedCompositor()
        renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)

        # Load reference image
        filename = "simple_pointcloud_sphere.png"
        image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)

        images = renderer(pointclouds)
        for i in range(batch_size):
            rgb = images[i, ..., :3].squeeze().cpu()
            if i == 0 and DEBUG:
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertClose(rgb, image_ref)