mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
coarse rasterization bug fix
Summary: Fix a bug which resulted in a rendering artifacts if the image size was not a multiple of 16. Fix: Revert coarse rasterization to original implementation and only update fine rasterization to reverse the ordering of Y and X axis. This is much simpler than the previous approach! Additional changes: - updated mesh rendering end-end tests to check outputs from both naive and coarse to fine rasterization. - added pointcloud rendering end-end tests Reviewed By: gkioxari Differential Revision: D21102725 fbshipit-source-id: 2e7e1b013dd6dd12b3a00b79eb8167deddb2e89a
This commit is contained in:
parent
1e4749602d
commit
9ef1ee8455
Binary file not shown.
Before Width: | Height: | Size: 62 KiB After Width: | Height: | Size: 64 KiB |
File diff suppressed because one or more lines are too long
@ -556,18 +556,16 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
|||||||
// PixToNdc gives the location of the center of each pixel, so we
|
// PixToNdc gives the location of the center of each pixel, so we
|
||||||
// need to add/subtract a half pixel to get the true extent of the bin.
|
// need to add/subtract a half pixel to get the true extent of the bin.
|
||||||
// Reverse ordering of Y axis so that +Y is upwards in the image.
|
// Reverse ordering of Y axis so that +Y is upwards in the image.
|
||||||
const int yidx = num_bins - by;
|
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
|
||||||
const float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
|
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
|
||||||
const float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
|
|
||||||
|
|
||||||
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
|
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
|
||||||
|
|
||||||
for (int bx = 0; bx < num_bins; ++bx) {
|
for (int bx = 0; bx < num_bins; ++bx) {
|
||||||
// X coordinate of the left and right of the bin.
|
// X coordinate of the left and right of the bin.
|
||||||
// Reverse ordering of x axis so that +X is left.
|
// Reverse ordering of x axis so that +X is left.
|
||||||
const int xidx = num_bins - bx;
|
const float bin_x_max =
|
||||||
const float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
|
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
|
||||||
const float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
|
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
|
||||||
|
|
||||||
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
|
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
|
||||||
if (y_overlap && x_overlap) {
|
if (y_overlap && x_overlap) {
|
||||||
@ -629,6 +627,7 @@ torch::Tensor RasterizeMeshesCoarseCuda(
|
|||||||
const int N = num_faces_per_mesh.size(0);
|
const int N = num_faces_per_mesh.size(0);
|
||||||
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
|
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
|
||||||
const int M = max_faces_per_bin;
|
const int M = max_faces_per_bin;
|
||||||
|
|
||||||
if (num_bins >= 22) {
|
if (num_bins >= 22) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "Got " << num_bins << "; that's too many!";
|
ss << "Got " << num_bins << "; that's too many!";
|
||||||
@ -702,13 +701,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
|||||||
if (yi >= H || xi >= W)
|
if (yi >= H || xi >= W)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Reverse ordering of the X and Y axis so that
|
const float xf = PixToNdc(xi, W);
|
||||||
// in the image +Y is pointing up and +X is pointing left.
|
const float yf = PixToNdc(yi, H);
|
||||||
const int yidx = H - 1 - yi;
|
|
||||||
const int xidx = W - 1 - xi;
|
|
||||||
|
|
||||||
const float xf = PixToNdc(xidx, W);
|
|
||||||
const float yf = PixToNdc(yidx, H);
|
|
||||||
const float2 pxy = make_float2(xf, yf);
|
const float2 pxy = make_float2(xf, yf);
|
||||||
|
|
||||||
// This part looks like the naive rasterization kernel, except we use
|
// This part looks like the naive rasterization kernel, except we use
|
||||||
@ -743,7 +737,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
|||||||
// output for the current pixel.
|
// output for the current pixel.
|
||||||
// TODO: make sorting an option as only top k is needed, not sorted values.
|
// TODO: make sorting an option as only top k is needed, not sorted values.
|
||||||
BubbleSort(q, q_size);
|
BubbleSort(q, q_size);
|
||||||
const int pix_idx = n * H * W * K + yi * H * K + xi * K;
|
|
||||||
|
// Reverse ordering of the X and Y axis so that
|
||||||
|
// in the image +Y is pointing up and +X is pointing left.
|
||||||
|
const int yidx = H - 1 - yi;
|
||||||
|
const int xidx = W - 1 - xi;
|
||||||
|
const int pix_idx = n * H * W * K + yidx * H * K + xidx * K;
|
||||||
for (int k = 0; k < q_size; k++) {
|
for (int k = 0; k < q_size; k++) {
|
||||||
face_idxs[pix_idx + k] = q[k].idx;
|
face_idxs[pix_idx + k] = q[k].idx;
|
||||||
zbuf[pix_idx + k] = q[k].z;
|
zbuf[pix_idx + k] = q[k].z;
|
||||||
|
@ -430,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
|||||||
const int face_stop_idx =
|
const int face_stop_idx =
|
||||||
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
|
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
|
||||||
|
|
||||||
float bin_y_max = 1.0f;
|
float bin_y_min = -1.0f;
|
||||||
float bin_y_min = bin_y_max - bin_width;
|
float bin_y_max = bin_y_min + bin_width;
|
||||||
|
|
||||||
// Iterate through the horizontal bins from top to bottom.
|
// Iterate through the horizontal bins from top to bottom.
|
||||||
for (int by = 0; by < BH; ++by) {
|
for (int by = 0; by < BH; ++by) {
|
||||||
float bin_x_max = 1.0f;
|
float bin_x_min = -1.0f;
|
||||||
float bin_x_min = bin_x_max - bin_width;
|
float bin_x_max = bin_x_min + bin_width;
|
||||||
|
|
||||||
// Iterate through bins on this horizontal line, left to right.
|
// Iterate through bins on this horizontal line, left to right.
|
||||||
for (int bx = 0; bx < BW; ++bx) {
|
for (int bx = 0; bx < BW; ++bx) {
|
||||||
@ -473,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shift the bin to the left for the next loop iteration.
|
// Shift the bin to the right for the next loop iteration
|
||||||
bin_x_max = bin_x_min;
|
bin_x_min = bin_x_max;
|
||||||
bin_x_min = bin_x_min - bin_width;
|
bin_x_max = bin_x_min + bin_width;
|
||||||
}
|
}
|
||||||
// Shift the bin down for the next loop iteration.
|
// Shift the bin down for the next loop iteration
|
||||||
bin_y_max = bin_y_min;
|
bin_y_min = bin_y_max;
|
||||||
bin_y_min = bin_y_min - bin_width;
|
bin_y_max = bin_y_min + bin_width;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return bin_faces;
|
return bin_faces;
|
||||||
|
@ -95,7 +95,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
|
|||||||
const int n = i / (S * S); // Batch index
|
const int n = i / (S * S); // Batch index
|
||||||
const int pix_idx = i % (S * S);
|
const int pix_idx = i % (S * S);
|
||||||
|
|
||||||
// Reverse ordering of X and Y axes.
|
// Reverse ordering of the X and Y axis as the camera coordinates
|
||||||
|
// assume that +Y is pointing up and +X is pointing left.
|
||||||
const int yi = S - 1 - pix_idx / S;
|
const int yi = S - 1 - pix_idx / S;
|
||||||
const int xi = S - 1 - pix_idx % S;
|
const int xi = S - 1 - pix_idx % S;
|
||||||
|
|
||||||
@ -260,23 +261,20 @@ __global__ void RasterizePointsCoarseCudaKernel(
|
|||||||
// Get y extent for the bin. PixToNdc gives us the location of
|
// Get y extent for the bin. PixToNdc gives us the location of
|
||||||
// the center of each pixel, so we need to add/subtract a half
|
// the center of each pixel, so we need to add/subtract a half
|
||||||
// pixel to get the true extent of the bin.
|
// pixel to get the true extent of the bin.
|
||||||
// Reverse ordering of Y axis so that +Y is upwards in the image.
|
const float by0 = PixToNdc(by * bin_size, S) - half_pix;
|
||||||
const int yidx = num_bins - by;
|
const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix;
|
||||||
const float bin_y_max = PixToNdc(yidx * bin_size - 1, S) + half_pix;
|
const bool y_overlap = (py0 <= by1) && (by0 <= py1);
|
||||||
const float bin_y_min = PixToNdc((yidx - 1) * bin_size, S) - half_pix;
|
|
||||||
|
|
||||||
const bool y_overlap = (py0 <= bin_y_max) && (bin_y_min <= py1);
|
|
||||||
if (!y_overlap) {
|
if (!y_overlap) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
for (int bx = 0; bx < num_bins; ++bx) {
|
for (int bx = 0; bx < num_bins; ++bx) {
|
||||||
// Get x extent for the bin; again we need to adjust the
|
// Get x extent for the bin; again we need to adjust the
|
||||||
// output of PixToNdc by half a pixel.
|
// output of PixToNdc by half a pixel.
|
||||||
// Reverse ordering of x axis so that +X is left.
|
const float bx0 = PixToNdc(bx * bin_size, S) - half_pix;
|
||||||
const int xidx = num_bins - bx;
|
const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix;
|
||||||
const float bin_x_max = PixToNdc(xidx * bin_size - 1, S) + half_pix;
|
const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
|
||||||
const float bin_x_min = PixToNdc((xidx - 1) * bin_size, S) - half_pix;
|
|
||||||
const bool x_overlap = (px0 <= bin_x_max) && (bin_x_min <= px1);
|
|
||||||
if (x_overlap) {
|
if (x_overlap) {
|
||||||
binmask.set(by, bx, p);
|
binmask.set(by, bx, p);
|
||||||
}
|
}
|
||||||
@ -330,6 +328,7 @@ torch::Tensor RasterizePointsCoarseCuda(
|
|||||||
const int N = num_points_per_cloud.size(0);
|
const int N = num_points_per_cloud.size(0);
|
||||||
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
|
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
|
||||||
const int M = max_points_per_bin;
|
const int M = max_points_per_bin;
|
||||||
|
|
||||||
if (points.ndimension() != 2 || points.size(1) != 3) {
|
if (points.ndimension() != 2 || points.size(1) != 3) {
|
||||||
AT_ERROR("points must have dimensions (num_points, 3)");
|
AT_ERROR("points must have dimensions (num_points, 3)");
|
||||||
}
|
}
|
||||||
@ -346,6 +345,7 @@ torch::Tensor RasterizePointsCoarseCuda(
|
|||||||
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
|
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
|
||||||
const size_t blocks = 64;
|
const size_t blocks = 64;
|
||||||
const size_t threads = 512;
|
const size_t threads = 512;
|
||||||
|
|
||||||
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
|
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
|
||||||
points.contiguous().data_ptr<float>(),
|
points.contiguous().data_ptr<float>(),
|
||||||
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
|
||||||
@ -372,7 +372,7 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
const float radius,
|
const float radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int N,
|
const int N,
|
||||||
const int B,
|
const int B, // num_bins
|
||||||
const int M,
|
const int M,
|
||||||
const int S,
|
const int S,
|
||||||
const int K,
|
const int K,
|
||||||
@ -397,19 +397,15 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
i %= B * bin_size * bin_size;
|
i %= B * bin_size * bin_size;
|
||||||
const int bx = i / (bin_size * bin_size);
|
const int bx = i / (bin_size * bin_size);
|
||||||
i %= bin_size * bin_size;
|
i %= bin_size * bin_size;
|
||||||
|
|
||||||
const int yi = i / bin_size + by * bin_size;
|
const int yi = i / bin_size + by * bin_size;
|
||||||
const int xi = i % bin_size + bx * bin_size;
|
const int xi = i % bin_size + bx * bin_size;
|
||||||
|
|
||||||
if (yi >= S || xi >= S)
|
if (yi >= S || xi >= S)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
// Reverse ordering of the X and Y axis so that
|
const float xf = PixToNdc(xi, S);
|
||||||
// in the image +Y is pointing up and +X is pointing left.
|
const float yf = PixToNdc(yi, S);
|
||||||
const int yidx = S - 1 - yi;
|
|
||||||
const int xidx = S - 1 - xi;
|
|
||||||
|
|
||||||
const float xf = PixToNdc(xidx, S);
|
|
||||||
const float yf = PixToNdc(yidx, S);
|
|
||||||
|
|
||||||
// This part looks like the naive rasterization kernel, except we use
|
// This part looks like the naive rasterization kernel, except we use
|
||||||
// bin_points to only look at a subset of points already known to fall
|
// bin_points to only look at a subset of points already known to fall
|
||||||
@ -431,7 +427,13 @@ __global__ void RasterizePointsFineCudaKernel(
|
|||||||
// Now we've looked at all the points for this bin, so we can write
|
// Now we've looked at all the points for this bin, so we can write
|
||||||
// output for the current pixel.
|
// output for the current pixel.
|
||||||
BubbleSort(q, q_size);
|
BubbleSort(q, q_size);
|
||||||
const int pix_idx = n * S * S * K + yi * S * K + xi * K;
|
|
||||||
|
// Reverse ordering of the X and Y axis as the camera coordinates
|
||||||
|
// assume that +Y is pointing up and +X is pointing left.
|
||||||
|
const int yidx = S - 1 - yi;
|
||||||
|
const int xidx = S - 1 - xi;
|
||||||
|
|
||||||
|
const int pix_idx = n * S * S * K + yidx * S * K + xidx * K;
|
||||||
for (int k = 0; k < q_size; ++k) {
|
for (int k = 0; k < q_size; ++k) {
|
||||||
point_idxs[pix_idx + k] = q[k].idx;
|
point_idxs[pix_idx + k] = q[k].idx;
|
||||||
zbuf[pix_idx + k] = q[k].z;
|
zbuf[pix_idx + k] = q[k].z;
|
||||||
@ -448,7 +450,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
|
|||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int points_per_pixel) {
|
const int points_per_pixel) {
|
||||||
const int N = bin_points.size(0);
|
const int N = bin_points.size(0);
|
||||||
const int B = bin_points.size(1);
|
const int B = bin_points.size(1); // num_bins
|
||||||
const int M = bin_points.size(3);
|
const int M = bin_points.size(3);
|
||||||
const int S = image_size;
|
const int S = image_size;
|
||||||
const int K = points_per_pixel;
|
const int K = points_per_pixel;
|
||||||
|
@ -125,13 +125,13 @@ torch::Tensor RasterizePointsCoarseCpu(
|
|||||||
const int point_stop_idx =
|
const int point_stop_idx =
|
||||||
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
|
||||||
|
|
||||||
float bin_y_max = 1.0f;
|
float bin_y_min = -1.0f;
|
||||||
float bin_y_min = bin_y_max - bin_width;
|
float bin_y_max = bin_y_min + bin_width;
|
||||||
|
|
||||||
// Iterate through the horizontal bins from top to bottom.
|
// Iterate through the horizontal bins from top to bottom.
|
||||||
for (int by = 0; by < B; by++) {
|
for (int by = 0; by < B; by++) {
|
||||||
float bin_x_max = 1.0f;
|
float bin_x_min = -1.0f;
|
||||||
float bin_x_min = bin_x_max - bin_width;
|
float bin_x_max = bin_x_min + bin_width;
|
||||||
|
|
||||||
// Iterate through bins on this horizontal line, left to right.
|
// Iterate through bins on this horizontal line, left to right.
|
||||||
for (int bx = 0; bx < B; bx++) {
|
for (int bx = 0; bx < B; bx++) {
|
||||||
@ -166,13 +166,13 @@ torch::Tensor RasterizePointsCoarseCpu(
|
|||||||
// Record the number of points found in this bin
|
// Record the number of points found in this bin
|
||||||
points_per_bin_a[n][by][bx] = points_hit;
|
points_per_bin_a[n][by][bx] = points_hit;
|
||||||
|
|
||||||
// Shift the bin to the left for the next loop iteration.
|
// Shift the bin to the right for the next loop iteration
|
||||||
bin_x_max = bin_x_min;
|
bin_x_min = bin_x_max;
|
||||||
bin_x_min = bin_x_min - bin_width;
|
bin_x_max = bin_x_min + bin_width;
|
||||||
}
|
}
|
||||||
// Shift the bin down for the next loop iteration.
|
// Shift the bin down for the next loop iteration
|
||||||
bin_y_max = bin_y_min;
|
bin_y_min = bin_y_max;
|
||||||
bin_y_min = bin_y_min - bin_width;
|
bin_y_max = bin_y_min + bin_width;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return bin_points;
|
return bin_points;
|
||||||
|
@ -9,7 +9,7 @@ from pytorch3d import _C
|
|||||||
|
|
||||||
|
|
||||||
# TODO make the epsilon user configurable
|
# TODO make the epsilon user configurable
|
||||||
kEpsilon = 1e-30
|
kEpsilon = 1e-8
|
||||||
|
|
||||||
|
|
||||||
def rasterize_meshes(
|
def rasterize_meshes(
|
||||||
|
@ -19,12 +19,28 @@ class PointFragments(NamedTuple):
|
|||||||
|
|
||||||
|
|
||||||
# Class to store the point rasterization params with defaults
|
# Class to store the point rasterization params with defaults
|
||||||
class PointsRasterizationSettings(NamedTuple):
|
class PointsRasterizationSettings:
|
||||||
image_size: int = 256
|
__slots__ = [
|
||||||
radius: float = 0.01
|
"image_size",
|
||||||
points_per_pixel: int = 8
|
"radius",
|
||||||
bin_size: Optional[int] = None
|
"points_per_pixel",
|
||||||
max_points_per_bin: Optional[int] = None
|
"bin_size",
|
||||||
|
"max_points_per_bin",
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_size: int = 256,
|
||||||
|
radius: float = 0.01,
|
||||||
|
points_per_pixel: int = 8,
|
||||||
|
bin_size: Optional[int] = None,
|
||||||
|
max_points_per_bin: Optional[int] = None,
|
||||||
|
):
|
||||||
|
self.image_size = image_size
|
||||||
|
self.radius = radius
|
||||||
|
self.points_per_pixel = points_per_pixel
|
||||||
|
self.bin_size = bin_size
|
||||||
|
self.max_points_per_bin = max_points_per_bin
|
||||||
|
|
||||||
|
|
||||||
class PointsRasterizer(nn.Module):
|
class PointsRasterizer(nn.Module):
|
||||||
|
@ -1,10 +1,20 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
from typing import Callable, Optional, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def load_rgb_image(filename: str, data_dir: Union[str, Path]):
|
||||||
|
filepath = data_dir / filename
|
||||||
|
with Image.open(filepath) as raw_image:
|
||||||
|
image = torch.from_numpy(np.array(raw_image) / 255.0)
|
||||||
|
image = image.to(dtype=torch.float32)
|
||||||
|
return image[..., :3]
|
||||||
|
|
||||||
|
|
||||||
TensorOrArray = Union[torch.Tensor, np.ndarray]
|
TensorOrArray = Union[torch.Tensor, np.ndarray]
|
||||||
|
BIN
tests/data/test_bridge_pointcloud.png
Normal file
BIN
tests/data/test_bridge_pointcloud.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 74 KiB |
BIN
tests/data/test_simple_pointcloud_sphere.png
Normal file
BIN
tests/data/test_simple_pointcloud_sphere.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 2.1 KiB |
@ -896,10 +896,10 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
torch.ones((1, 2, 2, max_faces_per_bin), dtype=torch.int32, device=device)
|
torch.ones((1, 2, 2, max_faces_per_bin), dtype=torch.int32, device=device)
|
||||||
* -1
|
* -1
|
||||||
)
|
)
|
||||||
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
|
|
||||||
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([1, 2])
|
|
||||||
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([0, 1])
|
|
||||||
bin_faces_expected[0, 1, 1, 0] = torch.tensor([1])
|
bin_faces_expected[0, 1, 1, 0] = torch.tensor([1])
|
||||||
|
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([1, 2])
|
||||||
|
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([0, 1])
|
||||||
|
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
|
||||||
|
|
||||||
# +Y up, +X left, +Z in
|
# +Y up, +X left, +Z in
|
||||||
bin_faces = _C._rasterize_meshes_coarse(
|
bin_faces = _C._rasterize_meshes_coarse(
|
||||||
@ -911,7 +911,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
bin_size,
|
bin_size,
|
||||||
max_faces_per_bin,
|
max_faces_per_bin,
|
||||||
)
|
)
|
||||||
# Flip x and y axis of output before comparing to expected
|
|
||||||
bin_faces_same = (bin_faces.squeeze() == bin_faces_expected).all()
|
bin_faces_same = (bin_faces.squeeze() == bin_faces_expected).all()
|
||||||
self.assertTrue(bin_faces_same.item() == 1)
|
self.assertTrue(bin_faces_same.item() == 1)
|
||||||
|
|
||||||
|
@ -434,23 +434,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def _test_coarse_rasterize(self, device):
|
def _test_coarse_rasterize(self, device):
|
||||||
#
|
#
|
||||||
# Note that +Y is up and +X is left in the diagram below.
|
|
||||||
#
|
#
|
||||||
# (4) |2
|
# |2 (4)
|
||||||
# |
|
# |
|
||||||
# |
|
# |
|
||||||
# |
|
# |
|
||||||
# |1
|
# |1
|
||||||
# |
|
# |
|
||||||
# (1) |
|
# | (1)
|
||||||
# | (2)
|
# (2)|
|
||||||
# ____________(0)__(5)___________________
|
# _________(5)___(0)_______________
|
||||||
# 2 1 | -1 -2
|
# -1 | 1 2
|
||||||
# |
|
# |
|
||||||
# (3) |
|
# | (3)
|
||||||
# |
|
# |
|
||||||
# |-1
|
# |-1
|
||||||
# |
|
|
||||||
#
|
#
|
||||||
# Locations of the points are shown by o. The screen bounding box
|
# Locations of the points are shown by o. The screen bounding box
|
||||||
# is between [-1, 1] in both the x and y directions.
|
# is between [-1, 1] in both the x and y directions.
|
||||||
@ -486,9 +484,9 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
|||||||
# fit in one chunk. This will the the case for this small example, but
|
# fit in one chunk. This will the the case for this small example, but
|
||||||
# to properly exercise coordianted writes among multiple chunks we need
|
# to properly exercise coordianted writes among multiple chunks we need
|
||||||
# to use a bigger test case.
|
# to use a bigger test case.
|
||||||
bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3])
|
bin_points_expected[0, 0, 1, :2] = torch.tensor([0, 3])
|
||||||
bin_points_expected[0, 0, 1, 0] = torch.tensor([2])
|
bin_points_expected[0, 1, 0, 0] = torch.tensor([2])
|
||||||
bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1])
|
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
|
||||||
|
|
||||||
pointclouds = Pointclouds(points=[points])
|
pointclouds = Pointclouds(points=[points])
|
||||||
args = (
|
args = (
|
||||||
@ -502,4 +500,5 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
bin_points = _C._rasterize_points_coarse(*args)
|
bin_points = _C._rasterize_points_coarse(*args)
|
||||||
bin_points_same = (bin_points == bin_points_expected).all()
|
bin_points_same = (bin_points == bin_points_expected).all()
|
||||||
|
|
||||||
self.assertTrue(bin_points_same.item() == 1)
|
self.assertTrue(bin_points_same.item() == 1)
|
||||||
|
@ -9,6 +9,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from common_testing import TestCaseMixin, load_rgb_image
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytorch3d.io import load_objs_as_meshes
|
from pytorch3d.io import load_objs_as_meshes
|
||||||
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
|
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
|
||||||
@ -35,15 +36,7 @@ DEBUG = False
|
|||||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||||
|
|
||||||
|
|
||||||
def load_rgb_image(filename, data_dir=DATA_DIR):
|
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||||
filepath = data_dir / filename
|
|
||||||
with Image.open(filepath) as raw_image:
|
|
||||||
image = torch.from_numpy(np.array(raw_image) / 255.0)
|
|
||||||
image = image.to(dtype=torch.float32)
|
|
||||||
return image[..., :3]
|
|
||||||
|
|
||||||
|
|
||||||
class TestRenderingMeshes(unittest.TestCase):
|
|
||||||
def test_simple_sphere(self, elevated_camera=False):
|
def test_simple_sphere(self, elevated_camera=False):
|
||||||
"""
|
"""
|
||||||
Test output of phong and gouraud shading matches a reference image using
|
Test output of phong and gouraud shading matches a reference image using
|
||||||
@ -81,7 +74,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||||
|
|
||||||
raster_settings = RasterizationSettings(
|
raster_settings = RasterizationSettings(
|
||||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||||
)
|
)
|
||||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
|
|
||||||
@ -96,14 +89,14 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||||
images = renderer(sphere_mesh)
|
images = renderer(sphere_mesh)
|
||||||
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
||||||
image_ref = load_rgb_image("test_%s" % filename)
|
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
filename = "DEBUG_" % filename
|
filename = "DEBUG_%s" % filename
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
DATA_DIR / filename
|
DATA_DIR / filename
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
self.assertClose(rgb, image_ref, atol=0.05)
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# Move the light to the +z axis in world space so it is
|
# Move the light to the +z axis in world space so it is
|
||||||
@ -124,8 +117,10 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load reference image
|
# Load reference image
|
||||||
image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
|
image_ref_phong_dark = load_rgb_image(
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
|
"test_simple_sphere_dark%s.png" % postfix, DATA_DIR
|
||||||
|
)
|
||||||
|
self.assertClose(rgb, image_ref_phong_dark, atol=0.05)
|
||||||
|
|
||||||
def test_simple_sphere_elevated_camera(self):
|
def test_simple_sphere_elevated_camera(self):
|
||||||
"""
|
"""
|
||||||
@ -160,7 +155,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
R, T = look_at_view_transform(dist, elev, azim)
|
R, T = look_at_view_transform(dist, elev, azim)
|
||||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||||
raster_settings = RasterizationSettings(
|
raster_settings = RasterizationSettings(
|
||||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init shader settings
|
# Init shader settings
|
||||||
@ -179,10 +174,12 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
|
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
|
||||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||||
images = renderer(sphere_meshes)
|
images = renderer(sphere_meshes)
|
||||||
image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
|
image_ref = load_rgb_image(
|
||||||
|
"test_simple_sphere_light_%s.png" % name, DATA_DIR
|
||||||
|
)
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
rgb = images[i, ..., :3].squeeze().cpu()
|
rgb = images[i, ..., :3].squeeze().cpu()
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
self.assertClose(rgb, image_ref, atol=0.05)
|
||||||
|
|
||||||
def test_silhouette_with_grad(self):
|
def test_silhouette_with_grad(self):
|
||||||
"""
|
"""
|
||||||
@ -200,7 +197,6 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
image_size=512,
|
image_size=512,
|
||||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||||
faces_per_pixel=80,
|
faces_per_pixel=80,
|
||||||
bin_size=0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init rasterizer settings
|
# Init rasterizer settings
|
||||||
@ -222,7 +218,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
with Image.open(image_ref_filename) as raw_image_ref:
|
with Image.open(image_ref_filename) as raw_image_ref:
|
||||||
image_ref = torch.from_numpy(np.array(raw_image_ref))
|
image_ref = torch.from_numpy(np.array(raw_image_ref))
|
||||||
image_ref = image_ref.to(dtype=torch.float32) / 255.0
|
image_ref = image_ref.to(dtype=torch.float32) / 255.0
|
||||||
self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))
|
self.assertClose(alpha, image_ref, atol=0.055)
|
||||||
|
|
||||||
# Check grad exist
|
# Check grad exist
|
||||||
verts.requires_grad = True
|
verts.requires_grad = True
|
||||||
@ -237,8 +233,8 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
The pupils in the eyes of the cow should always be looking to the left.
|
The pupils in the eyes of the cow should always be looking to the left.
|
||||||
"""
|
"""
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||||
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
|
obj_filename = obj_dir / "cow_mesh/cow.obj"
|
||||||
|
|
||||||
# Load mesh + texture
|
# Load mesh + texture
|
||||||
mesh = load_objs_as_meshes([obj_filename], device=device)
|
mesh = load_objs_as_meshes([obj_filename], device=device)
|
||||||
@ -247,7 +243,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
R, T = look_at_view_transform(2.7, 0, 0)
|
R, T = look_at_view_transform(2.7, 0, 0)
|
||||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||||
raster_settings = RasterizationSettings(
|
raster_settings = RasterizationSettings(
|
||||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init shader settings
|
# Init shader settings
|
||||||
@ -265,22 +261,26 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
lights=lights, cameras=cameras, materials=materials
|
lights=lights, cameras=cameras, materials=materials
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
images = renderer(mesh)
|
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
|
||||||
|
|
||||||
# Load reference image
|
# Load reference image
|
||||||
image_ref = load_rgb_image("test_texture_map_back.png")
|
image_ref = load_rgb_image("test_texture_map_back.png", DATA_DIR)
|
||||||
|
|
||||||
if DEBUG:
|
for bin_size in [0, None]:
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
# Check both naive and coarse to fine produce the same output.
|
||||||
DATA_DIR / "DEBUG_texture_map_back.png"
|
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||||
)
|
images = renderer(mesh)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
|
||||||
# NOTE some pixels can be flaky and will not lead to
|
if DEBUG:
|
||||||
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
DATA_DIR / "DEBUG_texture_map_back.png"
|
||||||
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
|
)
|
||||||
self.assertTrue(cond1 or cond2)
|
|
||||||
|
# NOTE some pixels can be flaky and will not lead to
|
||||||
|
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
|
||||||
|
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
||||||
|
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
|
||||||
|
self.assertTrue(cond1 or cond2)
|
||||||
|
|
||||||
# Check grad exists
|
# Check grad exists
|
||||||
[verts] = mesh.verts_list()
|
[verts] = mesh.verts_list()
|
||||||
@ -299,16 +299,27 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
|
|
||||||
# Move light to the front of the cow in world space
|
# Move light to the front of the cow in world space
|
||||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||||
images = renderer(mesh, cameras=cameras, lights=lights)
|
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
|
||||||
|
|
||||||
# Load reference image
|
# Load reference image
|
||||||
image_ref = load_rgb_image("test_texture_map_front.png")
|
image_ref = load_rgb_image("test_texture_map_front.png", DATA_DIR)
|
||||||
|
|
||||||
if DEBUG:
|
for bin_size in [0, None]:
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
# Check both naive and coarse to fine produce the same output.
|
||||||
DATA_DIR / "DEBUG_texture_map_front.png"
|
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||||
)
|
|
||||||
|
images = renderer(mesh, cameras=cameras, lights=lights)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / "DEBUG_texture_map_front.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE some pixels can be flaky and will not lead to
|
||||||
|
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
|
||||||
|
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
||||||
|
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
|
||||||
|
self.assertTrue(cond1 or cond2)
|
||||||
|
|
||||||
#################################
|
#################################
|
||||||
# Add blurring to rasterization
|
# Add blurring to rasterization
|
||||||
@ -320,23 +331,26 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
image_size=512,
|
image_size=512,
|
||||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||||
faces_per_pixel=100,
|
faces_per_pixel=100,
|
||||||
bin_size=0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
images = renderer(
|
|
||||||
mesh.clone(),
|
|
||||||
cameras=cameras,
|
|
||||||
raster_settings=raster_settings,
|
|
||||||
blend_params=blend_params,
|
|
||||||
)
|
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
|
||||||
|
|
||||||
# Load reference image
|
# Load reference image
|
||||||
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
|
image_ref = load_rgb_image("test_blurry_textured_rendering.png", DATA_DIR)
|
||||||
|
|
||||||
if DEBUG:
|
for bin_size in [0, None]:
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
# Check both naive and coarse to fine produce the same output.
|
||||||
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
|
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||||
|
|
||||||
|
images = renderer(
|
||||||
|
mesh.clone(),
|
||||||
|
cameras=cameras,
|
||||||
|
raster_settings=raster_settings,
|
||||||
|
blend_params=blend_params,
|
||||||
)
|
)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
if DEBUG:
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertClose(rgb, image_ref, atol=0.05)
|
173
tests/test_render_points.py
Normal file
173
tests/test_render_points.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Sanity checks for output images from the pointcloud renderer.
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
import warnings
|
||||||
|
from os import path
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from common_testing import TestCaseMixin, load_rgb_image
|
||||||
|
from PIL import Image
|
||||||
|
from pytorch3d.renderer.cameras import (
|
||||||
|
OpenGLOrthographicCameras,
|
||||||
|
OpenGLPerspectiveCameras,
|
||||||
|
look_at_view_transform,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.points import (
|
||||||
|
AlphaCompositor,
|
||||||
|
NormWeightedCompositor,
|
||||||
|
PointsRasterizationSettings,
|
||||||
|
PointsRasterizer,
|
||||||
|
PointsRenderer,
|
||||||
|
)
|
||||||
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||||
|
|
||||||
|
|
||||||
|
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||||
|
# All saved images have prefix DEBUG_
|
||||||
|
DEBUG = False
|
||||||
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
||||||
|
def test_simple_sphere(self):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
sphere_mesh = ico_sphere(1, device)
|
||||||
|
verts_padded = sphere_mesh.verts_padded()
|
||||||
|
# Shift vertices to check coordinate frames are correct.
|
||||||
|
verts_padded[..., 1] += 0.2
|
||||||
|
verts_padded[..., 0] += 0.2
|
||||||
|
pointclouds = Pointclouds(
|
||||||
|
points=verts_padded, features=torch.ones_like(verts_padded)
|
||||||
|
)
|
||||||
|
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||||
|
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||||
|
raster_settings = PointsRasterizationSettings(
|
||||||
|
image_size=256, radius=5e-2, points_per_pixel=1
|
||||||
|
)
|
||||||
|
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
|
compositor = NormWeightedCompositor()
|
||||||
|
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
|
||||||
|
|
||||||
|
# Load reference image
|
||||||
|
filename = "simple_pointcloud_sphere.png"
|
||||||
|
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||||
|
|
||||||
|
for bin_size in [0, None]:
|
||||||
|
# Check both naive and coarse to fine produce the same output.
|
||||||
|
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||||
|
images = renderer(pointclouds)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
if DEBUG:
|
||||||
|
filename = "DEBUG_%s" % filename
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / filename
|
||||||
|
)
|
||||||
|
self.assertClose(rgb, image_ref)
|
||||||
|
|
||||||
|
def test_pointcloud_with_features(self):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||||
|
pointcloud_filename = file_dir / "PittsburghBridge/pointcloud.npz"
|
||||||
|
|
||||||
|
# Note, this file is too large to check in to the repo.
|
||||||
|
# Download the file to run the test locally.
|
||||||
|
if not path.exists(pointcloud_filename):
|
||||||
|
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz"
|
||||||
|
msg = (
|
||||||
|
"pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
|
||||||
|
% (url, pointcloud_filename)
|
||||||
|
)
|
||||||
|
warnings.warn(msg)
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Load point cloud
|
||||||
|
pointcloud = np.load(pointcloud_filename)
|
||||||
|
verts = torch.Tensor(pointcloud["verts"]).to(device)
|
||||||
|
rgb_feats = torch.Tensor(pointcloud["rgb"]).to(device)
|
||||||
|
|
||||||
|
verts.requires_grad = True
|
||||||
|
rgb_feats.requires_grad = True
|
||||||
|
point_cloud = Pointclouds(points=[verts], features=[rgb_feats])
|
||||||
|
|
||||||
|
R, T = look_at_view_transform(20, 10, 0)
|
||||||
|
cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)
|
||||||
|
|
||||||
|
raster_settings = PointsRasterizationSettings(
|
||||||
|
# Set image_size so it is not a multiple of 16 (min bin_size)
|
||||||
|
# in order to confirm that there are no errors in coarse rasterization.
|
||||||
|
image_size=500,
|
||||||
|
radius=0.003,
|
||||||
|
points_per_pixel=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
renderer = PointsRenderer(
|
||||||
|
rasterizer=PointsRasterizer(
|
||||||
|
cameras=cameras, raster_settings=raster_settings
|
||||||
|
),
|
||||||
|
compositor=AlphaCompositor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
images = renderer(point_cloud)
|
||||||
|
|
||||||
|
# Load reference image
|
||||||
|
filename = "bridge_pointcloud.png"
|
||||||
|
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||||
|
|
||||||
|
for bin_size in [0, None]:
|
||||||
|
# Check both naive and coarse to fine produce the same output.
|
||||||
|
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||||
|
images = renderer(point_cloud)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
if DEBUG:
|
||||||
|
filename = "DEBUG_%s" % filename
|
||||||
|
Image.fromarray((rgb.detach().numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / filename
|
||||||
|
)
|
||||||
|
self.assertClose(rgb, image_ref, atol=0.015)
|
||||||
|
|
||||||
|
# Check grad exists.
|
||||||
|
grad_images = torch.randn_like(images)
|
||||||
|
images.backward(grad_images)
|
||||||
|
self.assertIsNotNone(verts.grad)
|
||||||
|
self.assertIsNotNone(rgb_feats.grad)
|
||||||
|
|
||||||
|
def test_simple_sphere_batched(self):
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
sphere_mesh = ico_sphere(1, device)
|
||||||
|
verts_padded = sphere_mesh.verts_padded()
|
||||||
|
verts_padded[..., 1] += 0.2
|
||||||
|
verts_padded[..., 0] += 0.2
|
||||||
|
pointclouds = Pointclouds(
|
||||||
|
points=verts_padded, features=torch.ones_like(verts_padded)
|
||||||
|
)
|
||||||
|
batch_size = 20
|
||||||
|
pointclouds = pointclouds.extend(batch_size)
|
||||||
|
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||||
|
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||||
|
raster_settings = PointsRasterizationSettings(
|
||||||
|
image_size=256, radius=5e-2, points_per_pixel=1
|
||||||
|
)
|
||||||
|
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
|
compositor = NormWeightedCompositor()
|
||||||
|
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
|
||||||
|
|
||||||
|
# Load reference image
|
||||||
|
filename = "simple_pointcloud_sphere.png"
|
||||||
|
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||||
|
|
||||||
|
images = renderer(pointclouds)
|
||||||
|
for i in range(batch_size):
|
||||||
|
rgb = images[i, ..., :3].squeeze().cpu()
|
||||||
|
if i == 0 and DEBUG:
|
||||||
|
filename = "DEBUG_%s" % filename
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / filename
|
||||||
|
)
|
||||||
|
self.assertClose(rgb, image_ref)
|
Loading…
x
Reference in New Issue
Block a user