Fix coordinate system conventions in renderer
Summary:

## Updates
- Defined the world and camera coordinates according to this figure. The world coordinates are defined as having +Y up, +X left and +Z in. {F230888499}
- Removed all flipping from the blending functions.
- Updated the rasterizer to return images with +Y up and +X left.
- Updated all the mesh rasterizer tests:
  - The expected values are now defined in terms of the default +Y up, +X left convention.
  - Added tests where the triangles in the meshes are non-symmetrical so that it is clear which direction +X and +Y point.

## Questions:
- Should we have **scene settings** instead of raster settings?
- To be more correct we should be [z clipping in the rasterizer based on the far/near clipping planes](https://github.com/ShichenLiu/SoftRas/blob/master/soft_renderer/cuda/soft_rasterize_cuda_kernel.cu#L400). These values are also required in the blending functions, so should we make them scene-level parameters and have a scene settings tuple which is available to both the rasterizer and the shader?

Reviewed By: gkioxari

Differential Revision: D20208604

fbshipit-source-id: 55787301b1bffa0afa9618f0a0886cc681da51f3
BIN docs/notes/assets/world_camera_image.png (new file)
@ -34,19 +34,22 @@ The differentiable renderer API is experimental and subject to change!.
|
|||||||
|
|
||||||
### Coordinate transformation conventions
|
### Coordinate transformation conventions
|
||||||
|
|
||||||
Rendering requires transformations between several different coordinate frames: world space, view/camera space, NDC space and screen space. At each step it is important to know where the camera is located, how the x, y, z axes are aligned and the possible range of values. The following figure outlines the conventions used in PyTorch3D.
|
Rendering requires transformations between several different coordinate frames: world space, view/camera space, NDC space and screen space. At each step it is important to know where the camera is located, how the +X, +Y, +Z axes are aligned and the possible range of values. The following figure outlines the conventions used in PyTorch3D.
|
||||||
|
|
||||||
<img src="assets/transformations_overview.png" width="1000">
|
<img src="assets/transformations_overview.png" width="1000">
|
||||||
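A hedged sketch (not part of this diff) of the conventions above: the camera is positioned relative to the +Y up, +X left, +Z in world frame, and `cameras.transform_points` is assumed to apply the full world → view → NDC projection.

```python
import torch
from pytorch3d.renderer import OpenGLPerspectiveCameras, look_at_view_transform

# Camera at distance 2.7 from the origin, slightly elevated and rotated.
R, T = look_at_view_transform(dist=2.7, elev=10, azim=20)
cameras = OpenGLPerspectiveCameras(device=torch.device("cpu"), R=R, T=T)

# World -> NDC for a single point at the world origin (assumed API).
points_world = torch.tensor([[[0.0, 0.0, 0.0]]])
points_ndc = cameras.transform_points(points_world)
```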
|
|
||||||
|
|
||||||
|
For example, given a teapot mesh, the world coordinate frame, camera coordinate frame and image are shown in the figure below. Note that the world and camera coordinate frames have the +Z direction pointing into the page.
|
||||||
|
|
||||||
|
<img src="assets/world_camera_image.png" width="1000">
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**NOTE: PyTorch3d vs OpenGL**
|
**NOTE: PyTorch3d vs OpenGL**
|
||||||
|
|
||||||
While we tried to emulate several aspects of OpenGL, the NDC coordinate system in PyTorch3d is **right-handed** compared with a **left-handed** NDC coordinate system in OpenGL (the projection matrix switches the handedness).
|
While we tried to emulate several aspects of OpenGL, there are differences in the coordinate frame conventions.
|
||||||
|
- The default world coordinate frame in PyTorch3D has +Z pointing into the screen, whereas in OpenGL, +Z points out of the screen. Both are right-handed.
|
||||||
In OpenGL, the camera at the origin is looking along the `-z` axis in camera space, but it is looking along the `+z` axis in NDC space.
|
- The NDC coordinate system in PyTorch3d is **right-handed** compared with a **left-handed** NDC coordinate system in OpenGL (the projection matrix switches the handedness).
|
||||||
|
|
||||||
<img align="center" src="assets/opengl_coordframes.png" width="300">
|
<img align="center" src="assets/opengl_coordframes.png" width="300">
|
||||||
|
|
||||||
@ -60,7 +63,7 @@ A renderer in PyTorch3d is composed of a **rasterizer** and a **shader**. Create
|
|||||||
from pytorch3d.renderer import (
|
from pytorch3d.renderer import (
|
||||||
OpenGLPerspectiveCameras, look_at_view_transform,
|
OpenGLPerspectiveCameras, look_at_view_transform,
|
||||||
RasterizationSettings, BlendParams,
|
RasterizationSettings, BlendParams,
|
||||||
MeshRenderer, MeshRasterizer, PhongShader
|
MeshRenderer, MeshRasterizer, HardPhongShader
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize an OpenGL perspective camera.
|
# Initialize an OpenGL perspective camera.
|
||||||
@ -81,7 +84,7 @@ raster_settings = RasterizationSettings(
|
|||||||
# PhongShader, passing in the device on which to initialize the default parameters
|
# HardPhongShader, passing in the device on which to initialize the default parameters
|
||||||
renderer = MeshRenderer(
|
renderer = MeshRenderer(
|
||||||
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||||
shader=PhongShader(device=device, cameras=cameras)
|
shader=HardPhongShader(device=device, cameras=cameras)
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
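A hedged usage sketch to follow the snippet above: `renderer` is the object constructed there, `my_mesh` is a hypothetical textured `Meshes` object, and the output layout (N, H, W, 4) with +Y up and +X left reflects the convention introduced in this change.

```python
# `my_mesh` is a hypothetical pytorch3d.structures.Meshes object with textures.
images = renderer(meshes_world=my_mesh)   # (N, H, W, 4) RGBA tensor
rgb = images[0, ..., :3]                  # (H, W, 3); +Y is up and +X is left
```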
|
|
||||||
|
@ -189,12 +189,12 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
|||||||
const float* face_verts,
|
const float* face_verts,
|
||||||
const int64_t* mesh_to_face_first_idx,
|
const int64_t* mesh_to_face_first_idx,
|
||||||
const int64_t* num_faces_per_mesh,
|
const int64_t* num_faces_per_mesh,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
bool perspective_correct,
|
const bool perspective_correct,
|
||||||
int N,
|
const int N,
|
||||||
int H,
|
const int H,
|
||||||
int W,
|
const int W,
|
||||||
int K,
|
const int K,
|
||||||
int64_t* face_idxs,
|
int64_t* face_idxs,
|
||||||
float* zbuf,
|
float* zbuf,
|
||||||
float* pix_dists,
|
float* pix_dists,
|
||||||
@ -207,8 +207,10 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
|||||||
// Convert linear index to 3D index
|
// Convert linear index to 3D index
|
||||||
const int n = i / (H * W); // batch index.
|
const int n = i / (H * W); // batch index.
|
||||||
const int pix_idx = i % (H * W);
|
const int pix_idx = i % (H * W);
|
||||||
const int yi = pix_idx / H;
|
|
||||||
const int xi = pix_idx % W;
|
// Determine ordering based on axis convention.
|
||||||
|
const int yi = H - 1 - pix_idx / W;
|
||||||
|
const int xi = W - 1 - pix_idx % W;
|
||||||
|
|
||||||
// screen coordinates to ndc coordinates of pixel.
|
// screen coordinates to ndc coordinates of pixel.
|
||||||
const float xf = PixToNdc(xi, W);
|
const float xf = PixToNdc(xi, W);
|
||||||
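For reference, a minimal Python sketch (assumed semantics, not the CUDA code itself) of the reversed pixel → NDC mapping used above: the linear pixel index is reversed in both axes so the returned image has +Y up and +X left, and `pix_to_ndc` gives the NDC coordinate of the pixel center.

```python
def pix_to_ndc(i, S):
    # NDC spans [-1, 1]; offset to the center of pixel i out of S.
    return -1 + (2 * i + 1.0) / S

H = W = 4
pix_idx = 0                       # first pixel of the linear index (top-left of the image)
yi = H - 1 - pix_idx // W         # reversed row index    -> +Y up
xi = W - 1 - pix_idx % W          # reversed column index -> +X left
yf, xf = pix_to_ndc(yi, H), pix_to_ndc(xi, W)
print(xf, yf)                     # 0.75 0.75: the top-left pixel maps to the +X, +Y corner
```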
@ -254,7 +256,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
|||||||
|
|
||||||
// TODO: make sorting an option as only top k is needed, not sorted values.
|
// TODO: make sorting an option as only top k is needed, not sorted values.
|
||||||
BubbleSort(q, q_size);
|
BubbleSort(q, q_size);
|
||||||
int idx = n * H * W * K + yi * H * K + xi * K;
|
int idx = n * H * W * K + pix_idx * K;
|
||||||
for (int k = 0; k < q_size; ++k) {
|
for (int k = 0; k < q_size; ++k) {
|
||||||
face_idxs[idx + k] = q[k].idx;
|
face_idxs[idx + k] = q[k].idx;
|
||||||
zbuf[idx + k] = q[k].z;
|
zbuf[idx + k] = q[k].z;
|
||||||
@ -274,7 +276,7 @@ RasterizeMeshesNaiveCuda(
|
|||||||
const int image_size,
|
const int image_size,
|
||||||
const float blur_radius,
|
const float blur_radius,
|
||||||
const int num_closest,
|
const int num_closest,
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||||
face_verts.size(2) != 3) {
|
face_verts.size(2) != 3) {
|
||||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||||
@ -331,12 +333,12 @@ RasterizeMeshesNaiveCuda(
|
|||||||
__global__ void RasterizeMeshesBackwardCudaKernel(
|
__global__ void RasterizeMeshesBackwardCudaKernel(
|
||||||
const float* face_verts, // (F, 3, 3)
|
const float* face_verts, // (F, 3, 3)
|
||||||
const int64_t* pix_to_face, // (N, H, W, K)
|
const int64_t* pix_to_face, // (N, H, W, K)
|
||||||
bool perspective_correct,
|
const bool perspective_correct,
|
||||||
int N,
|
const int N,
|
||||||
int F,
|
const int F,
|
||||||
int H,
|
const int H,
|
||||||
int W,
|
const int W,
|
||||||
int K,
|
const int K,
|
||||||
const float* grad_zbuf, // (N, H, W, K)
|
const float* grad_zbuf, // (N, H, W, K)
|
||||||
const float* grad_bary, // (N, H, W, K, 3)
|
const float* grad_bary, // (N, H, W, K, 3)
|
||||||
const float* grad_dists, // (N, H, W, K)
|
const float* grad_dists, // (N, H, W, K)
|
||||||
@ -351,8 +353,11 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
|||||||
// Convert linear index to 3D index
|
// Convert linear index to 3D index
|
||||||
const int n = t_i / (H * W); // batch index.
|
const int n = t_i / (H * W); // batch index.
|
||||||
const int pix_idx = t_i % (H * W);
|
const int pix_idx = t_i % (H * W);
|
||||||
const int yi = pix_idx / H;
|
|
||||||
const int xi = pix_idx % W;
|
// Determine ordering based on axis convention.
|
||||||
|
const int yi = H - 1 - pix_idx / W;
|
||||||
|
const int xi = W - 1 - pix_idx % W;
|
||||||
|
|
||||||
const float xf = PixToNdc(xi, W);
|
const float xf = PixToNdc(xi, W);
|
||||||
const float yf = PixToNdc(yi, H);
|
const float yf = PixToNdc(yi, H);
|
||||||
const float2 pxy = make_float2(xf, yf);
|
const float2 pxy = make_float2(xf, yf);
|
||||||
@ -360,8 +365,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
|||||||
// Loop over all the faces for this pixel.
|
// Loop over all the faces for this pixel.
|
||||||
for (int k = 0; k < K; k++) {
|
for (int k = 0; k < K; k++) {
|
||||||
// Index into (N, H, W, K, :) grad tensors
|
// Index into (N, H, W, K, :) grad tensors
|
||||||
const int i =
|
// pixel index + top k index
|
||||||
n * H * W * K + yi * H * K + xi * K + k; // pixel index + face index
|
int i = n * H * W * K + pix_idx * K + k;
|
||||||
|
|
||||||
const int f = pix_to_face[i];
|
const int f = pix_to_face[i];
|
||||||
if (f < 0) {
|
if (f < 0) {
|
||||||
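A small illustration (an assumption about the buffer layout, not part of the diff) of the flattened indexing above: `i = n*H*W*K + pix_idx*K + k` addresses element `(n, pix_idx // W, pix_idx % W, k)` of a contiguous `(N, H, W, K)` tensor.

```python
import torch

N, H, W, K = 2, 3, 4, 5
t = torch.arange(N * H * W * K).view(N, H, W, K)

n, pix_idx, k = 1, 7, 4                   # pix_idx = row * W + col
i = n * H * W * K + pix_idx * K + k       # same formula as the kernel above
assert t.reshape(-1)[i] == t[n, pix_idx // W, pix_idx % W, k]
```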
@ -451,7 +456,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
|
|||||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||||
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||||
const torch::Tensor& grad_dists, // (N, H, W, K)
|
const torch::Tensor& grad_dists, // (N, H, W, K)
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
const int F = face_verts.size(0);
|
const int F = face_verts.size(0);
|
||||||
const int N = pix_to_face.size(0);
|
const int N = pix_to_face.size(0);
|
||||||
const int H = pix_to_face.size(1);
|
const int H = pix_to_face.size(1);
|
||||||
@ -509,6 +514,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
|||||||
// Have each block handle a chunk of faces
|
// Have each block handle a chunk of faces
|
||||||
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
|
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
|
||||||
const int num_chunks = N * chunks_per_batch;
|
const int num_chunks = N * chunks_per_batch;
|
||||||
|
|
||||||
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
|
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
|
||||||
const int batch_idx = chunk / chunks_per_batch; // batch index
|
const int batch_idx = chunk / chunks_per_batch; // batch index
|
||||||
const int chunk_idx = chunk % chunks_per_batch;
|
const int chunk_idx = chunk % chunks_per_batch;
|
||||||
@ -551,17 +557,21 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
|
|||||||
// Y coordinate of the top and bottom of the bin.
|
// Y coordinate of the top and bottom of the bin.
|
||||||
// PixToNdc gives the location of the center of each pixel, so we
|
// PixToNdc gives the location of the center of each pixel, so we
|
||||||
// need to add/subtract a half pixel to get the true extent of the bin.
|
// need to add/subtract a half pixel to get the true extent of the bin.
|
||||||
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
|
// Reverse ordering of Y axis so that +Y is upwards in the image.
|
||||||
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
|
const int yidx = num_bins - by;
|
||||||
|
float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
|
||||||
|
float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
|
||||||
|
|
||||||
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
|
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
|
||||||
|
|
||||||
for (int bx = 0; bx < num_bins; ++bx) {
|
for (int bx = 0; bx < num_bins; ++bx) {
|
||||||
// X coordinate of the left and right of the bin.
|
// X coordinate of the left and right of the bin.
|
||||||
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
|
// Reverse ordering of x axis so that +X is left.
|
||||||
const float bin_x_max =
|
const int xidx = num_bins - bx;
|
||||||
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
|
float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
|
||||||
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
|
float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
|
||||||
|
|
||||||
|
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
|
||||||
if (y_overlap && x_overlap) {
|
if (y_overlap && x_overlap) {
|
||||||
binmask.set(by, bx, f);
|
binmask.set(by, bx, f);
|
||||||
}
|
}
|
||||||
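A hedged Python sketch (assumed semantics) of the reversed bin extents computed above: each bin spans `bin_size` pixels, its NDC extent is the center of its first/last pixel padded by half a pixel, and the reversed bin index makes bin 0 the +Y (top) / +X (left) bin.

```python
def pix_to_ndc(i, S):
    return -1 + (2 * i + 1.0) / S

S, bin_size = 8, 4               # image side and bin size in pixels (illustrative values)
num_bins = S // bin_size
half_pix = 1.0 / S
for by in range(num_bins):
    yidx = num_bins - by
    bin_y_max = pix_to_ndc(yidx * bin_size - 1, S) + half_pix
    bin_y_min = pix_to_ndc((yidx - 1) * bin_size, S) - half_pix
    print(by, bin_y_min, bin_y_max)   # by=0 -> (0.0, 1.0), by=1 -> (-1.0, 0.0)
```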
@ -654,7 +664,6 @@ torch::Tensor RasterizeMeshesCoarseCuda(
|
|||||||
// ****************************************************************************
|
// ****************************************************************************
|
||||||
// * FINE RASTERIZATION *
|
// * FINE RASTERIZATION *
|
||||||
// ****************************************************************************
|
// ****************************************************************************
|
||||||
|
|
||||||
__global__ void RasterizeMeshesFineCudaKernel(
|
__global__ void RasterizeMeshesFineCudaKernel(
|
||||||
const float* face_verts, // (F, 3, 3)
|
const float* face_verts, // (F, 3, 3)
|
||||||
const int32_t* bin_faces, // (N, B, B, T)
|
const int32_t* bin_faces, // (N, B, B, T)
|
||||||
@ -695,8 +704,14 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
|||||||
|
|
||||||
if (yi >= H || xi >= W)
|
if (yi >= H || xi >= W)
|
||||||
continue;
|
continue;
|
||||||
const float xf = PixToNdc(xi, W);
|
|
||||||
const float yf = PixToNdc(yi, H);
|
// Reverse ordering of the X and Y axes so that
|
||||||
|
// in the image +Y is pointing up and +X is pointing left.
|
||||||
|
const int yidx = H - 1 - yi;
|
||||||
|
const int xidx = W - 1 - xi;
|
||||||
|
|
||||||
|
const float xf = PixToNdc(xidx, W);
|
||||||
|
const float yf = PixToNdc(yidx, H);
|
||||||
const float2 pxy = make_float2(xf, yf);
|
const float2 pxy = make_float2(xf, yf);
|
||||||
|
|
||||||
// This part looks like the naive rasterization kernel, except we use
|
// This part looks like the naive rasterization kernel, except we use
|
||||||
@ -751,7 +766,7 @@ RasterizeMeshesFineCuda(
|
|||||||
const float blur_radius,
|
const float blur_radius,
|
||||||
const int bin_size,
|
const int bin_size,
|
||||||
const int faces_per_pixel,
|
const int faces_per_pixel,
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||||
face_verts.size(2) != 3) {
|
face_verts.size(2) != 3) {
|
||||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||||
|
@ -14,10 +14,10 @@ RasterizeMeshesNaiveCpu(
|
|||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int faces_per_pixel,
|
const int faces_per_pixel,
|
||||||
bool perspective_correct);
|
const bool perspective_correct);
|
||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||||
@ -25,10 +25,10 @@ RasterizeMeshesNaiveCuda(
|
|||||||
const at::Tensor& face_verts,
|
const at::Tensor& face_verts,
|
||||||
const at::Tensor& mesh_to_face_first_idx,
|
const at::Tensor& mesh_to_face_first_idx,
|
||||||
const at::Tensor& num_faces_per_mesh,
|
const at::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int num_closest,
|
const int num_closest,
|
||||||
bool perspective_correct);
|
const bool perspective_correct);
|
||||||
#endif
|
#endif
|
||||||
// Forward pass for rasterizing a batch of meshes.
|
// Forward pass for rasterizing a batch of meshes.
|
||||||
//
|
//
|
||||||
@ -77,10 +77,10 @@ RasterizeMeshesNaive(
|
|||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int faces_per_pixel,
|
const int faces_per_pixel,
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
// TODO: Better type checking.
|
// TODO: Better type checking.
|
||||||
if (face_verts.type().is_cuda()) {
|
if (face_verts.type().is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
@ -117,7 +117,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
|||||||
const torch::Tensor& grad_bary,
|
const torch::Tensor& grad_bary,
|
||||||
const torch::Tensor& grad_zbuf,
|
const torch::Tensor& grad_zbuf,
|
||||||
const torch::Tensor& grad_dists,
|
const torch::Tensor& grad_dists,
|
||||||
bool perspective_correct);
|
const bool perspective_correct);
|
||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
torch::Tensor RasterizeMeshesBackwardCuda(
|
torch::Tensor RasterizeMeshesBackwardCuda(
|
||||||
@ -126,7 +126,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
|
|||||||
const torch::Tensor& grad_bary,
|
const torch::Tensor& grad_bary,
|
||||||
const torch::Tensor& grad_zbuf,
|
const torch::Tensor& grad_zbuf,
|
||||||
const torch::Tensor& grad_dists,
|
const torch::Tensor& grad_dists,
|
||||||
bool perspective_correct);
|
const bool perspective_correct);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Args:
|
// Args:
|
||||||
@ -159,7 +159,7 @@ torch::Tensor RasterizeMeshesBackward(
|
|||||||
const torch::Tensor& grad_zbuf,
|
const torch::Tensor& grad_zbuf,
|
||||||
const torch::Tensor& grad_bary,
|
const torch::Tensor& grad_bary,
|
||||||
const torch::Tensor& grad_dists,
|
const torch::Tensor& grad_dists,
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
if (face_verts.type().is_cuda()) {
|
if (face_verts.type().is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
return RasterizeMeshesBackwardCuda(
|
return RasterizeMeshesBackwardCuda(
|
||||||
@ -191,20 +191,20 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
|||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const at::Tensor& mesh_to_face_first_idx,
|
const at::Tensor& mesh_to_face_first_idx,
|
||||||
const at::Tensor& num_faces_per_mesh,
|
const at::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int bin_size,
|
const int bin_size,
|
||||||
int max_faces_per_bin);
|
const int max_faces_per_bin);
|
||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
torch::Tensor RasterizeMeshesCoarseCuda(
|
torch::Tensor RasterizeMeshesCoarseCuda(
|
||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int bin_size,
|
const int bin_size,
|
||||||
int max_faces_per_bin);
|
const int max_faces_per_bin);
|
||||||
#endif
|
#endif
|
||||||
// Args:
|
// Args:
|
||||||
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
|
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
|
||||||
@ -232,10 +232,10 @@ torch::Tensor RasterizeMeshesCoarse(
|
|||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int bin_size,
|
const int bin_size,
|
||||||
int max_faces_per_bin) {
|
const int max_faces_per_bin) {
|
||||||
if (face_verts.type().is_cuda()) {
|
if (face_verts.type().is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
return RasterizeMeshesCoarseCuda(
|
return RasterizeMeshesCoarseCuda(
|
||||||
@ -270,11 +270,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
|||||||
RasterizeMeshesFineCuda(
|
RasterizeMeshesFineCuda(
|
||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& bin_faces,
|
const torch::Tensor& bin_faces,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int bin_size,
|
const int bin_size,
|
||||||
int faces_per_pixel,
|
const int faces_per_pixel,
|
||||||
bool perspective_correct);
|
const bool perspective_correct);
|
||||||
#endif
|
#endif
|
||||||
// Args:
|
// Args:
|
||||||
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
|
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
|
||||||
@ -317,11 +317,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
|
|||||||
RasterizeMeshesFine(
|
RasterizeMeshesFine(
|
||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& bin_faces,
|
const torch::Tensor& bin_faces,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int bin_size,
|
const int bin_size,
|
||||||
int faces_per_pixel,
|
const int faces_per_pixel,
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
if (face_verts.type().is_cuda()) {
|
if (face_verts.type().is_cuda()) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
return RasterizeMeshesFineCuda(
|
return RasterizeMeshesFineCuda(
|
||||||
@ -373,6 +373,7 @@ RasterizeMeshesFine(
|
|||||||
// this function instead returns screen-space
|
// this function instead returns screen-space
|
||||||
// barycentric coordinates for each pixel.
|
// barycentric coordinates for each pixel.
|
||||||
//
|
//
|
||||||
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// A 4 element tuple of:
|
// A 4 element tuple of:
|
||||||
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
|
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
|
||||||
@ -394,12 +395,12 @@ RasterizeMeshes(
|
|||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int faces_per_pixel,
|
const int faces_per_pixel,
|
||||||
int bin_size,
|
const int bin_size,
|
||||||
int max_faces_per_bin,
|
const int max_faces_per_bin,
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
if (bin_size > 0 && max_faces_per_bin > 0) {
|
if (bin_size > 0 && max_faces_per_bin > 0) {
|
||||||
// Use coarse-to-fine rasterization
|
// Use coarse-to-fine rasterization
|
||||||
auto bin_faces = RasterizeMeshesCoarse(
|
auto bin_faces = RasterizeMeshesCoarse(
|
||||||
|
@ -105,9 +105,9 @@ RasterizeMeshesNaiveCpu(
|
|||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int faces_per_pixel,
|
const int faces_per_pixel,
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||||
face_verts.size(2) != 3) {
|
face_verts.size(2) != 3) {
|
||||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||||
@ -153,12 +153,19 @@ RasterizeMeshesNaiveCpu(
|
|||||||
|
|
||||||
// Iterate through the horizontal lines of the image from top to bottom.
|
// Iterate through the horizontal lines of the image from top to bottom.
|
||||||
for (int yi = 0; yi < H; ++yi) {
|
for (int yi = 0; yi < H; ++yi) {
|
||||||
|
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
||||||
|
const int yidx = H - 1 - yi;
|
||||||
|
|
||||||
// Y coordinate of the top of the pixel.
|
// Y coordinate of the top of the pixel.
|
||||||
const float yf = PixToNdc(yi, H);
|
const float yf = PixToNdc(yidx, H);
|
||||||
// Iterate through pixels on this horizontal line, left to right.
|
// Iterate through pixels on this horizontal line, left to right.
|
||||||
for (int xi = 0; xi < W; ++xi) {
|
for (int xi = 0; xi < W; ++xi) {
|
||||||
|
// Reverse the order of xi so that +X is pointing to the left in the
|
||||||
|
// image.
|
||||||
|
const int xidx = W - 1 - xi;
|
||||||
|
|
||||||
// X coordinate of the left of the pixel.
|
// X coordinate of the left of the pixel.
|
||||||
const float xf = PixToNdc(xi, W);
|
const float xf = PixToNdc(xidx, W);
|
||||||
// Use a priority queue to hold values:
|
// Use a priority queue to hold values:
|
||||||
// (z, idx, r, bary.x, bary.y. bary.z)
|
// (z, idx, r, bary.x, bary.y. bary.z)
|
||||||
std::priority_queue<std::tuple<float, int, float, float, float, float>>
|
std::priority_queue<std::tuple<float, int, float, float, float, float>>
|
||||||
@ -250,7 +257,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
|||||||
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
const torch::Tensor& grad_zbuf, // (N, H, W, K)
|
||||||
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
const torch::Tensor& grad_bary, // (N, H, W, K, 3)
|
||||||
const torch::Tensor& grad_dists, // (N, H, W, K)
|
const torch::Tensor& grad_dists, // (N, H, W, K)
|
||||||
bool perspective_correct) {
|
const bool perspective_correct) {
|
||||||
const int F = face_verts.size(0);
|
const int F = face_verts.size(0);
|
||||||
const int N = pix_to_face.size(0);
|
const int N = pix_to_face.size(0);
|
||||||
const int H = pix_to_face.size(1);
|
const int H = pix_to_face.size(1);
|
||||||
@ -267,12 +274,19 @@ torch::Tensor RasterizeMeshesBackwardCpu(
|
|||||||
for (int n = 0; n < N; ++n) {
|
for (int n = 0; n < N; ++n) {
|
||||||
// Iterate through the horizontal lines of the image from top to bottom.
|
// Iterate through the horizontal lines of the image from top to bottom.
|
||||||
for (int y = 0; y < H; ++y) {
|
for (int y = 0; y < H; ++y) {
|
||||||
|
// Reverse the order of yi so that +Y is pointing upwards in the image.
|
||||||
|
const int yidx = H - 1 - y;
|
||||||
|
|
||||||
// Y coordinate of the top of the pixel.
|
// Y coordinate of the top of the pixel.
|
||||||
const float yf = PixToNdc(y, H);
|
const float yf = PixToNdc(yidx, H);
|
||||||
// Iterate through pixels on this horizontal line, left to right.
|
// Iterate through pixels on this horizontal line, left to right.
|
||||||
for (int x = 0; x < W; ++x) {
|
for (int x = 0; x < W; ++x) {
|
||||||
|
// Reverse the order of xi so that +X is pointing to the left in the
|
||||||
|
// image.
|
||||||
|
const int xidx = W - 1 - x;
|
||||||
|
|
||||||
// X coordinate of the left of the pixel.
|
// X coordinate of the left of the pixel.
|
||||||
const float xf = PixToNdc(x, W);
|
const float xf = PixToNdc(xidx, W);
|
||||||
const vec2<float> pxy(xf, yf);
|
const vec2<float> pxy(xf, yf);
|
||||||
|
|
||||||
// Iterate through the faces that hit this pixel.
|
// Iterate through the faces that hit this pixel.
|
||||||
@ -376,10 +390,10 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
|||||||
const torch::Tensor& face_verts,
|
const torch::Tensor& face_verts,
|
||||||
const torch::Tensor& mesh_to_face_first_idx,
|
const torch::Tensor& mesh_to_face_first_idx,
|
||||||
const torch::Tensor& num_faces_per_mesh,
|
const torch::Tensor& num_faces_per_mesh,
|
||||||
int image_size,
|
const int image_size,
|
||||||
float blur_radius,
|
const float blur_radius,
|
||||||
int bin_size,
|
const int bin_size,
|
||||||
int max_faces_per_bin) {
|
const int max_faces_per_bin) {
|
||||||
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
|
||||||
face_verts.size(2) != 3) {
|
face_verts.size(2) != 3) {
|
||||||
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
|
||||||
@ -387,6 +401,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
|||||||
if (num_faces_per_mesh.ndimension() != 1) {
|
if (num_faces_per_mesh.ndimension() != 1) {
|
||||||
AT_ERROR("num_faces_per_mesh can only have one dimension");
|
AT_ERROR("num_faces_per_mesh can only have one dimension");
|
||||||
}
|
}
|
||||||
|
|
||||||
const int N = num_faces_per_mesh.size(0); // batch size.
|
const int N = num_faces_per_mesh.size(0); // batch size.
|
||||||
const int M = max_faces_per_bin;
|
const int M = max_faces_per_bin;
|
||||||
|
|
||||||
@ -415,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
|||||||
const int face_stop_idx =
|
const int face_stop_idx =
|
||||||
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
|
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
|
||||||
|
|
||||||
float bin_y_min = -1.0f;
|
float bin_y_max = 1.0f;
|
||||||
float bin_y_max = bin_y_min + bin_width;
|
float bin_y_min = bin_y_max - bin_width;
|
||||||
|
|
||||||
// Iterate through the horizontal bins from top to bottom.
|
// Iterate through the horizontal bins from top to bottom.
|
||||||
for (int by = 0; by < BH; ++by) {
|
for (int by = 0; by < BH; ++by) {
|
||||||
float bin_x_min = -1.0f;
|
float bin_x_max = 1.0f;
|
||||||
float bin_x_max = bin_x_min + bin_width;
|
float bin_x_min = bin_x_max - bin_width;
|
||||||
|
|
||||||
// Iterate through bins on this horizontal line, left to right.
|
// Iterate through bins on this horizontal line, left to right.
|
||||||
for (int bx = 0; bx < BW; ++bx) {
|
for (int bx = 0; bx < BW; ++bx) {
|
||||||
@ -458,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shift the bin to the right for the next loop iteration.
|
// Shift the bin to the right (towards -X) for the next loop iteration.
|
||||||
bin_x_min = bin_x_max;
|
bin_x_max = bin_x_min;
|
||||||
bin_x_max = bin_x_min + bin_width;
|
bin_x_min = bin_x_min - bin_width;
|
||||||
}
|
}
|
||||||
// Shift the bin down for the next loop iteration.
|
// Shift the bin down (towards -Y) for the next loop iteration.
|
||||||
bin_y_min = bin_y_max;
|
bin_y_max = bin_y_min;
|
||||||
bin_y_max = bin_y_min + bin_width;
|
bin_y_min = bin_y_min - bin_width;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return bin_faces;
|
return bin_faces;
|
||||||
|
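A hedged sketch of the sweep direction in the CPU version above: the bins now start at the +Y / +X end of NDC space (the top-left of the image) and are shifted towards -Y / -X (down / right in the image) on each iteration.

```python
bin_width = 0.5
bin_y_max, bin_y_min = 1.0, 1.0 - bin_width
rows = []
while bin_y_min >= -1.0 - 1e-6:
    rows.append((bin_y_min, bin_y_max))
    # Shift the bin down (towards -Y) for the next iteration.
    bin_y_max, bin_y_min = bin_y_min, bin_y_min - bin_width
print(rows)   # [(0.5, 1.0), (0.0, 0.5), (-0.5, 0.0), (-1.0, -0.5)]
```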
@ -38,7 +38,7 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
|
|||||||
device = fragments.pix_to_face.device
|
device = fragments.pix_to_face.device
|
||||||
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
|
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
|
||||||
pixel_colors[..., :3] = colors[..., 0, :]
|
pixel_colors[..., :3] = colors[..., 0, :]
|
||||||
return torch.flip(pixel_colors, [1])
|
return pixel_colors
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||||
@ -80,7 +80,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
|
|||||||
alpha = torch.prod((1.0 - prob), dim=-1)
|
alpha = torch.prod((1.0 - prob), dim=-1)
|
||||||
pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB
|
pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB
|
||||||
pixel_colors[..., 3] = 1.0 - alpha
|
pixel_colors[..., 3] = 1.0 - alpha
|
||||||
return torch.flip(pixel_colors, [1])
|
return pixel_colors
|
||||||
|
|
||||||
|
|
||||||
def softmax_rgb_blend(
|
def softmax_rgb_blend(
|
||||||
@ -125,7 +125,7 @@ def softmax_rgb_blend(
|
|||||||
|
|
||||||
N, H, W, K = fragments.pix_to_face.shape
|
N, H, W, K = fragments.pix_to_face.shape
|
||||||
device = fragments.pix_to_face.device
|
device = fragments.pix_to_face.device
|
||||||
pix_colors = torch.ones(
|
pixel_colors = torch.ones(
|
||||||
(N, H, W, 4), dtype=colors.dtype, device=colors.device
|
(N, H, W, 4), dtype=colors.dtype, device=colors.device
|
||||||
)
|
)
|
||||||
background = blend_params.background_color
|
background = blend_params.background_color
|
||||||
@ -166,7 +166,7 @@ def softmax_rgb_blend(
|
|||||||
# Sum: weights * textures + background color
|
# Sum: weights * textures + background color
|
||||||
weighted_colors = (weights[..., None] * colors).sum(dim=-2)
|
weighted_colors = (weights[..., None] * colors).sum(dim=-2)
|
||||||
weighted_background = (delta / denom) * background
|
weighted_background = (delta / denom) * background
|
||||||
pix_colors[..., :3] = weighted_colors + weighted_background
|
pixel_colors[..., :3] = weighted_colors + weighted_background
|
||||||
pix_colors[..., 3] = 1.0 - alpha
|
pixel_colors[..., 3] = 1.0 - alpha
|
||||||
|
|
||||||
return torch.flip(pix_colors, [1])
|
return pixel_colors
|
||||||
|
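For context, a hedged sketch (based on the SoftRas-style formulation referenced in the summary, with placeholder inputs) of how `softmax_rgb_blend` weights the K faces per pixel before the background term is added.

```python
import torch

N, H, W, K = 1, 4, 4, 3
znear, zfar = 1.0, 100.0
sigma, gamma = 1e-4, 1e-4
dists = torch.randn(N, H, W, K) * 1e-4     # signed pixel-to-face distances (placeholder)
zbuf = torch.rand(N, H, W, K) * 10 + 1     # per-face depth for each pixel (placeholder)

prob = torch.sigmoid(-dists / sigma)                     # face coverage probability
z_inv = (zfar - zbuf) / (zfar - znear)                   # normalized inverse depth
z_inv_max = z_inv.max(dim=-1, keepdim=True).values
weights_num = prob * torch.exp((z_inv - z_inv_max) / gamma)
delta = torch.exp((torch.zeros_like(z_inv_max) - z_inv_max) / gamma)   # background term
denom = weights_num.sum(dim=-1, keepdim=True) + delta
weights = weights_num / denom                            # (N, H, W, K) per-face weights
```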
@ -944,7 +944,7 @@ def camera_position_from_spherical_angles(
|
|||||||
azim = math.pi / 180.0 * azim
|
azim = math.pi / 180.0 * azim
|
||||||
x = dist * torch.cos(elev) * torch.sin(azim)
|
x = dist * torch.cos(elev) * torch.sin(azim)
|
||||||
y = dist * torch.sin(elev)
|
y = dist * torch.sin(elev)
|
||||||
z = -dist * torch.cos(elev) * torch.cos(azim)
|
z = dist * torch.cos(elev) * torch.cos(azim)
|
||||||
camera_position = torch.stack([x, y, z], dim=1)
|
camera_position = torch.stack([x, y, z], dim=1)
|
||||||
if camera_position.dim() == 0:
|
if camera_position.dim() == 0:
|
||||||
camera_position = camera_position.view(1, -1) # add batch dim.
|
camera_position = camera_position.view(1, -1) # add batch dim.
|
||||||
|
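A worked example (not part of the diff) of the sign change above: with `azim = elev = 0` the camera is placed on the world +Z axis at `(0, 0, dist)`; `look_at_view_transform` then orients it to look back at the origin.

```python
import math

dist, elev, azim = 2.7, 0.0, 0.0
x = dist * math.cos(elev) * math.sin(azim)   # 0.0
y = dist * math.sin(elev)                    # 0.0
z = dist * math.cos(elev) * math.cos(azim)   # 2.7 (previously -2.7)
print((x, y, z))                             # (0.0, 0.0, 2.7)
```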
@ -208,6 +208,11 @@ class _RasterizeFaceVerts(torch.autograd.Function):
|
|||||||
return grads
|
return grads
|
||||||
|
|
||||||
|
|
||||||
|
def pix_to_ndc(i, S):
|
||||||
|
# NDC x-offset + (i * pixel_width + half_pixel_width)
|
||||||
|
return -1 + (2 * i + 1.0) / S
|
||||||
|
|
||||||
|
|
||||||
def rasterize_meshes_python(
|
def rasterize_meshes_python(
|
||||||
meshes,
|
meshes,
|
||||||
image_size: int = 256,
|
image_size: int = 256,
|
||||||
@ -249,10 +254,6 @@ def rasterize_meshes_python(
|
|||||||
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
|
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# NDC is from [-1, 1]. Get pixel size using specified image size.
|
|
||||||
pixel_width = 2.0 / W
|
|
||||||
pixel_height = 2.0 / H
|
|
||||||
|
|
||||||
# Calculate all face bounding boxes.
|
# Calculate all face bounding boxes.
|
||||||
x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
|
x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
|
||||||
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
|
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
|
||||||
@ -269,14 +270,20 @@ def rasterize_meshes_python(
|
|||||||
for n in range(N):
|
for n in range(N):
|
||||||
face_start_idx = mesh_to_face_first_idx[n]
|
face_start_idx = mesh_to_face_first_idx[n]
|
||||||
face_stop_idx = face_start_idx + num_faces_per_mesh[n]
|
face_stop_idx = face_start_idx + num_faces_per_mesh[n]
|
||||||
# Y coordinate of the top of the image.
|
|
||||||
yf = -1.0 + 0.5 * pixel_height
|
|
||||||
# Iterate through the horizontal lines of the image from top to bottom.
|
# Iterate through the horizontal lines of the image from top to bottom.
|
||||||
for yi in range(H):
|
for yi in range(H):
|
||||||
# X coordinate of the left of the image.
|
# Y coordinate of one end of the image. Reverse the ordering
|
||||||
xf = -1.0 + 0.5 * pixel_width
|
# of yi so that +Y is pointing up in the image.
|
||||||
|
yfix = H - 1 - yi
|
||||||
|
yf = pix_to_ndc(yfix, H)
|
||||||
|
|
||||||
# Iterate through pixels on this horizontal line, left to right.
|
# Iterate through pixels on this horizontal line, left to right.
|
||||||
for xi in range(W):
|
for xi in range(W):
|
||||||
|
# X coordinate of one end of the image. Reverse the ordering
|
||||||
|
# of xi so that +X is pointing to the left in the image.
|
||||||
|
xfix = W - 1 - xi
|
||||||
|
xf = pix_to_ndc(xfix, W)
|
||||||
top_k_points = []
|
top_k_points = []
|
||||||
|
|
||||||
# Check whether each face in the mesh affects this pixel.
|
# Check whether each face in the mesh affects this pixel.
|
||||||
@ -347,12 +354,6 @@ def rasterize_meshes_python(
|
|||||||
bary_coords[n, yi, xi, k, 2] = bary[2]
|
bary_coords[n, yi, xi, k, 2] = bary[2]
|
||||||
pix_dists[n, yi, xi, k] = dist
|
pix_dists[n, yi, xi, k] = dist
|
||||||
|
|
||||||
# Move to the next horizontal pixel
|
|
||||||
xf += pixel_width
|
|
||||||
|
|
||||||
# Move to the next vertical pixel
|
|
||||||
yf += pixel_height
|
|
||||||
|
|
||||||
return face_idxs, zbuf, bary_coords, pix_dists
|
return face_idxs, zbuf, bary_coords, pix_dists
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class MeshRenderer(nn.Module):
|
|||||||
if raster_settings.blur_radius > 0.0:
|
if raster_settings.blur_radius > 0.0:
|
||||||
# TODO: potentially move barycentric clipping to the rasterizer
|
# TODO: potentially move barycentric clipping to the rasterizer
|
||||||
# if no downstream functions requires unclipped values.
|
# if no downstream functions requires unclipped values.
|
||||||
# This will avoid unnecessary re-interpolation of the z buffer.
|
# This will avoid unnecessary re-interpolation of the z buffer.
|
||||||
clipped_bary_coords = _clip_barycentric_coordinates(
|
clipped_bary_coords = _clip_barycentric_coordinates(
|
||||||
fragments.bary_coords
|
fragments.bary_coords
|
||||||
)
|
)
|
||||||
@ -67,4 +67,5 @@ class MeshRenderer(nn.Module):
|
|||||||
pix_to_face=fragments.pix_to_face,
|
pix_to_face=fragments.pix_to_face,
|
||||||
)
|
)
|
||||||
images = self.shader(fragments, meshes_world, **kwargs)
|
images = self.shader(fragments, meshes_world, **kwargs)
|
||||||
|
|
||||||
return images
|
return images
|
||||||
|
BIN tests/data/test_simple_sphere_light_elevated_camera.png (new file)
BIN tests/data/test_texture_map_back.png (new file)
BIN tests/data/test_texture_map_front.png (new file)
@ -41,7 +41,7 @@ def sigmoid_blend_naive_loop(colors, fragments, blend_params):
|
|||||||
pixel_colors[n, h, w, :3] = colors[n, h, w, 0, :]
|
pixel_colors[n, h, w, :3] = colors[n, h, w, 0, :]
|
||||||
pixel_colors[n, h, w, 3] = 1.0 - alpha
|
pixel_colors[n, h, w, 3] = 1.0 - alpha
|
||||||
|
|
||||||
return torch.flip(pixel_colors, [1])
|
return pixel_colors
|
||||||
|
|
||||||
|
|
||||||
def sigmoid_blend_naive_loop_backward(
|
def sigmoid_blend_naive_loop_backward(
|
||||||
@ -54,8 +54,6 @@ def sigmoid_blend_naive_loop_backward(
|
|||||||
N, H, W, K = pix_to_face.shape
|
N, H, W, K = pix_to_face.shape
|
||||||
device = pix_to_face.device
|
device = pix_to_face.device
|
||||||
grad_distances = torch.zeros((N, H, W, K), dtype=dists.dtype, device=device)
|
grad_distances = torch.zeros((N, H, W, K), dtype=dists.dtype, device=device)
|
||||||
images = torch.flip(images, [1])
|
|
||||||
grad_images = torch.flip(grad_images, [1])
|
|
||||||
|
|
||||||
for n in range(N):
|
for n in range(N):
|
||||||
for h in range(H):
|
for h in range(H):
|
||||||
@ -130,7 +128,7 @@ def softmax_blend_naive(colors, fragments, blend_params):
|
|||||||
pixel_colors[n, h, w, :3] += (delta / denom) * bk_color
|
pixel_colors[n, h, w, :3] += (delta / denom) * bk_color
|
||||||
pixel_colors[n, h, w, 3] = 1.0 - alpha
|
pixel_colors[n, h, w, 3] = 1.0 - alpha
|
||||||
|
|
||||||
return torch.flip(pixel_colors, [1])
|
return pixel_colors
|
||||||
|
|
||||||
|
|
||||||
class TestBlending(unittest.TestCase):
|
class TestBlending(unittest.TestCase):
|
||||||
|
@ -173,12 +173,12 @@ class TestCameraHelpers(unittest.TestCase):
|
|||||||
grad_dist = (
|
grad_dist = (
|
||||||
torch.cos(elev) * torch.sin(azim)
|
torch.cos(elev) * torch.sin(azim)
|
||||||
+ torch.sin(elev)
|
+ torch.sin(elev)
|
||||||
- torch.cos(elev) * torch.cos(azim)
|
+ torch.cos(elev) * torch.cos(azim)
|
||||||
)
|
)
|
||||||
grad_elev = (
|
grad_elev = (
|
||||||
-torch.sin(elev) * torch.sin(azim)
|
-torch.sin(elev) * torch.sin(azim)
|
||||||
+ torch.cos(elev)
|
+ torch.cos(elev)
|
||||||
+ torch.sin(elev) * torch.cos(azim)
|
- torch.sin(elev) * torch.cos(azim)
|
||||||
)
|
)
|
||||||
grad_elev = dist * (math.pi / 180.0) * grad_elev
|
grad_elev = dist * (math.pi / 180.0) * grad_elev
|
||||||
self.assertTrue(torch.allclose(elev_grad, grad_elev))
|
self.assertTrue(torch.allclose(elev_grad, grad_elev))
|
||||||
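A hedged autograd check (not part of the test) of the analytic `grad_dist` above, using the updated `z = +dist*cos(elev)*cos(azim)` convention with elev/azim already in radians.

```python
import math
import torch

dist = torch.tensor(2.7, requires_grad=True)
elev = torch.tensor(math.radians(30.0))
azim = torch.tensor(math.radians(45.0))

x = dist * torch.cos(elev) * torch.sin(azim)
y = dist * torch.sin(elev)
z = dist * torch.cos(elev) * torch.cos(azim)
(x + y + z).backward()

grad_dist = torch.cos(elev) * torch.sin(azim) + torch.sin(elev) + torch.cos(elev) * torch.cos(azim)
assert torch.allclose(dist.grad, grad_dist)
```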
@ -232,12 +232,12 @@ class TestCameraHelpers(unittest.TestCase):
|
|||||||
grad_dist = (
|
grad_dist = (
|
||||||
torch.cos(elev) * torch.sin(azim)
|
torch.cos(elev) * torch.sin(azim)
|
||||||
+ torch.sin(elev)
|
+ torch.sin(elev)
|
||||||
- torch.cos(elev) * torch.cos(azim)
|
+ torch.cos(elev) * torch.cos(azim)
|
||||||
)
|
)
|
||||||
grad_elev = (
|
grad_elev = (
|
||||||
-torch.sin(elev) * torch.sin(azim)
|
-torch.sin(elev) * torch.sin(azim)
|
||||||
+ torch.cos(elev)
|
+ torch.cos(elev)
|
||||||
+ torch.sin(elev) * torch.cos(azim)
|
- torch.sin(elev) * torch.cos(azim)
|
||||||
)
|
)
|
||||||
grad_elev = (dist * (math.pi / 180.0) * grad_elev).sum()
|
grad_elev = (dist * (math.pi / 180.0) * grad_elev).sum()
|
||||||
self.assertTrue(torch.allclose(elev_grad, grad_elev))
|
self.assertTrue(torch.allclose(elev_grad, grad_elev))
|
||||||
|
@ -19,7 +19,7 @@ class TestRasterizeMeshes(unittest.TestCase):
|
|||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
self._simple_triangle_raster(
|
self._simple_triangle_raster(
|
||||||
rasterize_meshes_python, device, bin_size=-1
|
rasterize_meshes_python, device, bin_size=-1
|
||||||
) # don't set binsize
|
)
|
||||||
self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
|
self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
|
||||||
self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
|
self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
|
||||||
self._test_perspective_correct(
|
self._test_perspective_correct(
|
||||||
@ -28,10 +28,10 @@ class TestRasterizeMeshes(unittest.TestCase):
|
|||||||
|
|
||||||
def test_simple_cpu_naive(self):
|
def test_simple_cpu_naive(self):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
self._simple_triangle_raster(rasterize_meshes, device)
|
self._simple_triangle_raster(rasterize_meshes, device, bin_size=0)
|
||||||
self._simple_blurry_raster(rasterize_meshes, device)
|
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
|
||||||
self._test_behind_camera(rasterize_meshes, device)
|
self._test_behind_camera(rasterize_meshes, device, bin_size=0)
|
||||||
self._test_perspective_correct(rasterize_meshes, device)
|
self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
|
||||||
|
|
||||||
def test_simple_cuda_naive(self):
|
def test_simple_cuda_naive(self):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
@ -285,14 +285,15 @@ class TestRasterizeMeshes(unittest.TestCase):
|
|||||||
max_faces_per_bin = 20
|
max_faces_per_bin = 20
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
meshes = ico_sphere(2, device)
|
|
||||||
|
|
||||||
|
meshes = ico_sphere(2, device)
|
||||||
faces = meshes.faces_packed()
|
faces = meshes.faces_packed()
|
||||||
verts = meshes.verts_packed()
|
verts = meshes.verts_packed()
|
||||||
faces_verts = verts[faces]
|
faces_verts = verts[faces]
|
||||||
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
||||||
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||||
args = (
|
|
||||||
|
bin_faces_cpu = _C._rasterize_meshes_coarse(
|
||||||
faces_verts,
|
faces_verts,
|
||||||
mesh_to_face_first_idx,
|
mesh_to_face_first_idx,
|
||||||
num_faces_per_mesh,
|
num_faces_per_mesh,
|
||||||
@ -301,17 +302,16 @@ class TestRasterizeMeshes(unittest.TestCase):
|
|||||||
bin_size,
|
bin_size,
|
||||||
max_faces_per_bin,
|
max_faces_per_bin,
|
||||||
)
|
)
|
||||||
bin_faces_cpu = _C._rasterize_meshes_coarse(*args)
|
|
||||||
|
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
meshes = ico_sphere(2, device)
|
meshes = meshes.clone().to(device)
|
||||||
|
|
||||||
faces = meshes.faces_packed()
|
faces = meshes.faces_packed()
|
||||||
verts = meshes.verts_packed()
|
verts = meshes.verts_packed()
|
||||||
faces_verts = verts[faces]
|
faces_verts = verts[faces]
|
||||||
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
||||||
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||||
args = (
|
|
||||||
|
bin_faces_cuda = _C._rasterize_meshes_coarse(
|
||||||
faces_verts,
|
faces_verts,
|
||||||
mesh_to_face_first_idx,
|
mesh_to_face_first_idx,
|
||||||
num_faces_per_mesh,
|
num_faces_per_mesh,
|
||||||
@ -320,11 +320,11 @@ class TestRasterizeMeshes(unittest.TestCase):
|
|||||||
bin_size,
|
bin_size,
|
||||||
max_faces_per_bin,
|
max_faces_per_bin,
|
||||||
)
|
)
|
||||||
bin_faces_cuda = _C._rasterize_meshes_coarse(*args)
|
|
||||||
|
|
||||||
# Bin faces might not be the same: CUDA version might write them in
|
# Bin faces might not be the same: CUDA version might write them in
|
||||||
# any order. But if we sort the non-(-1) elements of the CUDA output
|
# any order. But if we sort the non-(-1) elements of the CUDA output
|
||||||
# then they should be the same.
|
# then they should be the same.
|
||||||
|
|
||||||
for n in range(N):
|
for n in range(N):
|
||||||
for by in range(bin_faces_cpu.shape[1]):
|
for by in range(bin_faces_cpu.shape[1]):
|
||||||
for bx in range(bin_faces_cpu.shape[2]):
|
for bx in range(bin_faces_cpu.shape[2]):
|
||||||
@ -456,64 +456,68 @@ class TestRasterizeMeshes(unittest.TestCase):
|
|||||||
|
|
||||||
# Run with and without perspective correction
|
# Run with and without perspective correction
|
||||||
idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs)
|
idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs)
|
||||||
|
|
||||||
kwargs["perspective_correct"] = True
|
kwargs["perspective_correct"] = True
|
||||||
idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs)
|
idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs)
|
||||||
|
|
||||||
|
# Expected output tensors in the format with axes +X left, +Y up, +Z in
|
||||||
# idx and dists should be the same with or without perspective correction
|
# idx and dists should be the same with or without perspective correction
|
||||||
# fmt: off
|
# fmt: off
|
||||||
idx_expected = torch.tensor([
|
idx_expected = torch.tensor([
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, 0, 0, 0, 0, 0, 0, 0, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], # noqa: E241, E201
|
|
||||||
[-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], # noqa: E241, E201
|
|
||||||
[-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, 0, 0, 0, 0, 0, 0, 0, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, 0, 0, 0, 0, 0, 0, 0, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, 0, 0, 0, 0, 0, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, 0, 0, 0, 0, 0, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, 0, 0, 0, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, 0, 0, 0, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, 0, 0, 0, 0, 0, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, 0, 0, 0, 0, 0, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, 0, 0, 0, 0, 0, 0, 0, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, 0, 0, 0, 0, 0, 0, 0, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], # noqa: E241, E201
|
||||||
|
[-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], # noqa: E241, E201
|
||||||
|
[-1, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, 0, 0, 0, 0, 0, 0, 0, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241, E201
|
||||||
], dtype=torch.int64, device=device).view(1, 11, 11, 1)
|
], dtype=torch.int64, device=device).view(1, 11, 11, 1)
|
||||||
|
|
||||||
dists_expected = torch.tensor([
|
dists_expected = torch.tensor([
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 0.1283, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1283, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., 0.1402, 0.1071, 0.1402, -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 0.1283, 0.0423, 0.0212, 0.0212, 0.0212, 0.0212, 0.0212, 0.0423, 0.1283, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., 0.1523, 0.0542, 0.0212, 0.0542, 0.1523, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 0.1084, 0.0225, -0.0003, -0.0013, -0.0013, -0.0013, -0.0003, 0.0225, 0.1084, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., 0.0955, 0.0214, -0.0003, 0.0214, 0.0955, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 0.1523, 0.0518, 0.0042, -0.0095, -0.0476, -0.0095, 0.0042, 0.0518, 0.1523, -1.0000], # noqa: E241, E201
|
[-1., -1., 0.1523, 0.0518, 0.0042, -0.0095, 0.0042, 0.0518, 0.1523, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 0.0955, 0.0214, -0.0003, -0.0320, -0.0003, 0.0214, 0.0955, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., 0.0955, 0.0214, -0.0003, -0.032, -0.0003, 0.0214, 0.0955, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 0.1523, 0.0518, 0.0042, -0.0095, 0.0042, 0.0518, 0.1523, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 0.1523, 0.0518, 0.0042, -0.0095, -0.0476, -0.0095, 0.0042, 0.0518, 0.1523, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, 0.0955, 0.0214, -0.0003, 0.0214, 0.0955, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 0.1084, 0.0225, -0.0003, -0.0013, -0.0013, -0.0013, -0.0003, 0.0225, 0.1084, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, 0.1523, 0.0542, 0.0212, 0.0542, 0.1523, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 0.1283, 0.0423, 0.0212, 0.0212, 0.0212, 0.0212, 0.0212, 0.0423, 0.1283, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, 0.1402, 0.1071, 0.1402, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., 0.1283, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1283, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241, E201
|
||||||
], dtype=torch.float32, device=device).view(1, 11, 11, 1)
|
], dtype=torch.float32, device=device).view(1, 11, 11, 1)
|
||||||
|
|
||||||
# zbuf and barycentric will be different with perspective correction
|
# zbuf and barycentric will be different with perspective correction
|
||||||
zbuf_f_expected = torch.tensor([
|
zbuf_f_expected = torch.tensor([
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 5.9091, 5.9091, 5.9091, 5.9091, 5.9091, 5.9091, 5.9091, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., 24.0909, 24.0909, 24.0909, -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., 21.8182, 21.8182, 21.8182, 21.8182, 21.8182, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., 19.5455, 19.5455, 19.5455, 19.5455, 19.5455, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, -1.0000], # noqa: E241, E201
|
[-1., -1., 17.2727, 17.2727, 17.2727, 17.2727, 17.2727, 17.2727, 17.2727, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 15.0000, 15.0000, 15.0000, 15.0000, 15.0000, 15.0000, 15.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., 15., 15., 15., 15., 15., 15., 15., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 17.2727, 17.2727, 17.2727, 17.2727, 17.2727, 17.2727, 17.2727, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, 12.7273, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, 19.5455, 19.5455, 19.5455, 19.5455, 19.5455, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, 10.4545, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, 21.8182, 21.8182, 21.8182, 21.8182, 21.8182, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, 8.1818, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, 24.0909, 24.0909, 24.0909, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., 5.9091, 5.9091, 5.9091, 5.9091, 5.9091, 5.9091, 5.9091, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
], dtype=torch.float32, device=device).view(1, 11, 11, 1)
|
], dtype=torch.float32, device=device).view(1, 11, 11, 1)
|
||||||
|
|
||||||
zbuf_t_expected = torch.tensor([
|
zbuf_t_expected = torch.tensor([
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 8.3019, 8.3019, 8.3019, 8.3019, 8.3019, 8.3019, 8.3019, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., 33.8461, 33.8462, 33.8462, -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., 24.4444, 24.4444, 24.4444, 24.4444, 24.4444, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., 19.1304, 19.1304, 19.1304, 19.1304, 19.1304, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, -1.0000], # noqa: E241, E201
|
[-1., -1., 15.7143, 15.7143, 15.7143, 15.7143, 15.7143, 15.7143, 15.7143, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, 13.3333, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, 15.7143, 15.7143, 15.7143, 15.7143, 15.7143, 15.7143, 15.7143, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, 11.5789, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, 19.1304, 19.1304, 19.1304, 19.1304, 19.1304, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, 10.2326, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, 24.4444, 24.4444, 24.4444, 24.4444, 24.4444, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, 9.1667, -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, 33.8462, 33.8462, 33.8461, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., 8.3019, 8.3019, 8.3019, 8.3019, 8.3019, 8.3019, 8.3019, -1., -1.], # noqa: E241, E201
|
||||||
[-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241, E201
|
||||||
], dtype=torch.float32, device=device).view(1, 11, 11, 1)
|
], dtype=torch.float32, device=device).view(1, 11, 11, 1)
|
||||||
        # fmt: on

@@ -618,9 +622,10 @@ class TestRasterizeMeshes(unittest.TestCase):
    def _simple_triangle_raster(self, raster_fn, device, bin_size=None):
        image_size = 10

-       # Mesh with a single face.
+       # Mesh with a single non-symmetrical face - this will help
+       # check that the XY directions are correctly oriented.
        verts0 = torch.tensor(
-           [[-0.7, -0.4, 0.1], [0.0, 0.6, 0.1], [0.7, -0.4, 0.1]],
+           [[-0.3, -0.4, 0.1], [0.0, 0.6, 0.1], [0.9, -0.4, 0.1]],
            dtype=torch.float32,
            device=device,
        )
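The asymmetric vertices above are what make a flipped axis convention detectable: under the +X left convention described in the summary, a vertex with a large positive NDC x lands near the left edge of the image. A minimal sketch of that intuition (the pixel-mapping formula here is an illustrative assumption, not the rasterizer's exact pixel-center rule):

```python
def ndc_x_to_col_x_left(x_ndc: float, image_size: int) -> int:
    # Illustrative mapping only: column 0 is the +X (left) edge of the image.
    return int(round((1.0 - x_ndc) / 2.0 * (image_size - 1)))

print(ndc_x_to_col_x_left(0.9, 10))   # ~0: the vertex at x = 0.9 sits far left
print(ndc_x_to_col_x_left(-0.3, 10))  # ~6: the vertex at x = -0.3 sits right of centre
```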
@@ -630,7 +635,7 @@ class TestRasterizeMeshes(unittest.TestCase):
        # fmt: off
        verts1 = torch.tensor(
            [
-               [-0.7, -0.4, 0.1], # noqa: E241, E201
+               [-0.9, -0.2, 0.1], # noqa: E241, E201
                [ 0.0, 0.6, 0.1], # noqa: E241, E201
                [ 0.7, -0.4, 0.1], # noqa: E241, E201
                [-0.7, 0.4, 0.5], # noqa: E241, E201
@@ -645,6 +650,8 @@ class TestRasterizeMeshes(unittest.TestCase):
            [[1, 0, 2], [3, 4, 5]], dtype=torch.int64, device=device
        )

+       # Expected output tensors in the format with axes +X left, +Y up, +Z in
+       # k = 0, closest point.
        # fmt off
        expected_p2face_k0 = torch.tensor(
            [
@@ -652,10 +659,10 @@ class TestRasterizeMeshes(unittest.TestCase):
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
-               [-1, -1,  0,  0,  0,  0,  0,  0, -1, -1], # noqa: E241, E201
+               [-1, -1, -1, -1,  0, -1, -1, -1, -1, -1], # noqa: E241, E201
-               [-1, -1, -1,  0,  0,  0,  0, -1, -1, -1], # noqa: E241, E201
+               [-1, -1, -1,  0,  0,  0, -1, -1, -1, -1], # noqa: E241, E201
-               [-1, -1, -1,  0,  0,  0,  0, -1, -1, -1], # noqa: E241, E201
+               [-1, -1,  0,  0,  0,  0, -1, -1, -1, -1], # noqa: E241, E201
-               [-1, -1, -1, -1,  0,  0, -1, -1, -1, -1], # noqa: E241, E201
+               [-1,  0,  0,  0,  0,  0, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
@@ -663,11 +670,11 @@ class TestRasterizeMeshes(unittest.TestCase):
            [
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
-               [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
+               [-1, -1, -1, -1, -1,  1, -1, -1, -1, -1], # noqa: E241, E201
-               [-1, -1,  1,  1,  1,  1,  1,  1, -1, -1], # noqa: E241, E201
+               [-1, -1,  2,  2,  1,  1,  1,  2, -1, -1], # noqa: E241, E201
-               [-1, -1, -1,  1,  1,  1,  1, -1, -1, -1], # noqa: E241, E201
+               [-1, -1, -1,  1,  1,  1,  1,  1, -1, -1], # noqa: E241, E201
-               [-1, -1, -1,  1,  1,  1,  1, -1, -1, -1], # noqa: E241, E201
+               [-1, -1, -1,  1,  1,  1,  1,  1,  1, -1], # noqa: E241, E201
-               [-1, -1,  2,  2,  1,  1,  2,  2, -1, -1], # noqa: E241, E201
+               [-1, -1,  1,  1,  1,  2, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
                [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
@@ -677,49 +684,37 @@ class TestRasterizeMeshes(unittest.TestCase):
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
expected_zbuf_k0 = torch.tensor(
|
expected_zbuf_k0 = torch.tensor(
|
||||||
|
[
|
||||||
[
|
[
|
||||||
[
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, 0.1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, 0.1, 0.1, 0.1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, 0.1, 0.1, 0.1, 0.1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, 0.1, 0.1, 0.1, 0.1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, 0.1, 0.1, 0.1, 0.1, -1, -1, -1], # noqa: E241, E201
|
[-1, 0.1, 0.1, 0.1, 0.1, 0.1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, 0.1, 0.1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, 0.1, 0.1, 0.1, 0.1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, 0.1, 0.1, 0.1, 0.1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, 0.5, 0.5, 0.1, 0.1, 0.5, 0.5, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
|
||||||
],
|
|
||||||
],
|
],
|
||||||
|
[
|
||||||
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, -1, -1, 0.1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, 0.5, 0.5, 0.1, 0.1, 0.1, 0.5, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, 0.1, 0.1, 0.1, 0.1, 0.1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, 0.1, 0.1, 0.1, 0.5, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241, E201
|
||||||
|
]
|
||||||
|
],
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
        # fmt: on

        meshes = Meshes(verts=[verts0, verts1], faces=[faces0, faces1])
-       if bin_size == -1:
-           # simple python case with no binning
-           p2face, zbuf, bary, pix_dists = raster_fn(
-               meshes, image_size, 0.0, 2
-           )
-       else:
-           p2face, zbuf, bary, pix_dists = raster_fn(
-               meshes, image_size, 0.0, 2, bin_size
-           )
-       # k = 0, closest point.
-       self.assertTrue(torch.allclose(p2face[..., 0], expected_p2face_k0))
-       self.assertTrue(torch.allclose(zbuf[..., 0], expected_zbuf_k0))

        # k = 1, second closest point.
        expected_p2face_k1 = expected_p2face_k0.clone()
@@ -729,18 +724,18 @@ class TestRasterizeMeshes(unittest.TestCase):
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
expected_p2face_k1[1, :] = torch.tensor(
|
expected_p2face_k1[1, :] = torch.tensor(
|
||||||
[
|
[
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, 2, 2, 2, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, 2, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241, E201
|
||||||
],
|
],
|
||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
@@ -748,21 +743,35 @@ class TestRasterizeMeshes(unittest.TestCase):
expected_zbuf_k1[0, :] = torch.ones_like(expected_zbuf_k1[0, :]) * -1
|
expected_zbuf_k1[0, :] = torch.ones_like(expected_zbuf_k1[0, :]) * -1
|
||||||
expected_zbuf_k1[1, :] = torch.tensor(
|
expected_zbuf_k1[1, :] = torch.tensor(
|
||||||
[
|
[
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, 0.5, 0.5, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., 0.5, 0.5, 0.5, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, 0.5, 0.5, 0.5, 0.5, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., 0.5, 0.5, 0.5, 0.5, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, 0.5, 0.5, 0.5, 0.5, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., 0.5, 0.5, 0.5, 0.5, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, 0.5, 0.5, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., 0.5, -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241, E201
|
||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
        # fmt: on

+       # Coordinate conventions +Y up, +Z in, +X left
+       if bin_size == -1:
+           # simple python, no bin_size
+           p2face, zbuf, bary, pix_dists = raster_fn(
+               meshes, image_size, 0.0, 2
+           )
+       else:
+           p2face, zbuf, bary, pix_dists = raster_fn(
+               meshes, image_size, 0.0, 2, bin_size
+           )

+       self.assertTrue(torch.allclose(p2face[..., 0], expected_p2face_k0))
+       self.assertTrue(torch.allclose(zbuf[..., 0], expected_zbuf_k0))
        self.assertTrue(torch.allclose(p2face[..., 1], expected_p2face_k1))
        self.assertTrue(torch.allclose(zbuf[..., 1], expected_zbuf_k1))

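The expected tensors in this test are written directly in the new +Y up, +X left image convention. Equivalently, an expectation written in the previous +Y down, +X right convention is the same array flipped along both spatial axes; a minimal torch sketch of that relationship (illustration only, not code from this diff):

```python
import torch

def flip_image_convention(image: torch.Tensor) -> torch.Tensor:
    # (N, H, W, K) image: flipping H and W converts between the
    # +Y down / +X right and +Y up / +X left conventions.
    return image.flip(dims=[1, 2])

old = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).view(1, 2, 2, 1)
new = flip_image_convention(old)
# new[0, ..., 0] is [[4., 3.], [2., 1.]]
```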
@@ -778,9 +787,9 @@ class TestRasterizeMeshes(unittest.TestCase):
        # fmt: off
        verts = torch.tensor(
            [
-               [ -0.5, 0.0, 0.1], # noqa: E241, E201
+               [ -0.3, 0.0, 0.1], # noqa: E241, E201
                [  0.0, 0.6, 0.1], # noqa: E241, E201
-               [  0.5, 0.0, 0.1], # noqa: E241, E201
+               [  0.8, 0.0, 0.1], # noqa: E241, E201
                [-0.25, 0.0, 0.9], # noqa: E241, E201
                [ 0.25, 0.5, 0.9], # noqa: E241, E201
                [ 0.75, 0.0, 0.9], # noqa: E241, E201
@@ -794,6 +803,8 @@ class TestRasterizeMeshes(unittest.TestCase):
            dtype=torch.float32,
            device=device,
        )
+       # Face with index 0 is non symmetric about the X and Y axis to
+       # test that the positive Y and X directions are correct in the output.
        faces_packed = torch.tensor(
            [[1, 0, 2], [4, 3, 5], [7, 6, 8], [10, 9, 11]],
            dtype=torch.int64,
@@ -803,12 +814,12 @@ class TestRasterizeMeshes(unittest.TestCase):
[
|
[
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, 2, 2, 0, 0, 0, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, 2, 0, 0, 0, 0, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, 0, 0, 0, 0, 0, 0, -1, -1, -1], # noqa: E241, E201
|
||||||
|
[-1, 0, 0, 0, 0, 0, 0, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, 0, 0, 0, 0, 0, 0, 2, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, 0, 0, 0, 0, 0, 0, 2, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, 0, 0, 0, 0, 2, 2, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, 0, 0, 2, 2, 2, -1], # noqa: E241, E201
|
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
||||||
],
|
],
|
||||||
@@ -817,16 +828,16 @@ class TestRasterizeMeshes(unittest.TestCase):
)
|
)
|
||||||
expected_zbuf = torch.tensor(
|
expected_zbuf = torch.tensor(
|
||||||
[
|
[
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., 0.5, 0.5, 0.1, 0.1, 0.1, -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., 0.5, 0.1, 0.1, 0.1, 0.1, -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.5, -1], # noqa: E241, E201
|
[-1., 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.5, -1], # noqa: E241, E201
|
[-1., 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, 0.1, 0.1, 0.1, 0.1, 0.5, 0.5, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, 0.1, 0.1, 0.5, 0.5, 0.5, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.], # noqa: E241, E201
|
||||||
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241, E201
|
[-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.] # noqa: E241, E201
|
||||||
],
|
],
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=device,
|
device=device,
|
||||||
@@ -837,7 +848,7 @@ class TestRasterizeMeshes(unittest.TestCase):
            faces = faces_packed[order]  # rearrange order of faces.
            mesh = Meshes(verts=[verts], faces=[faces])
            if bin_size == -1:
-               # simple python case with no binning
+               # simple python, no bin size arg
                pix_to_face, zbuf, bary_coords, dists = raster_fn(
                    mesh, image_size, blur_radius, faces_per_pixel
                )
@@ -845,7 +856,6 @@ class TestRasterizeMeshes(unittest.TestCase):
                pix_to_face, zbuf, bary_coords, dists = raster_fn(
                    mesh, image_size, blur_radius, faces_per_pixel, bin_size
                )

            if i == 0:
                expected_dists = dists
            p2f = expected_p2f.clone()
@@ -861,33 +871,37 @@ class TestRasterizeMeshes(unittest.TestCase):

    def _test_coarse_rasterize(self, device):
        image_size = 16
-       blur_radius = 0.2 ** 2
+       # No blurring. This test checks that the XY directions are
+       # correctly oriented.
+       blur_radius = 0.0
        bin_size = 8
        max_faces_per_bin = 3

        # fmt: off
        verts = torch.tensor(
            [
-               [-0.5,  0.0,  0.1], # noqa: E241, E201
+               [-0.5,  0.1,  0.1], # noqa: E241, E201
-               [ 0.0,  0.6,  0.1], # noqa: E241, E201
+               [-0.3,  0.6,  0.1], # noqa: E241, E201
-               [ 0.5,  0.0,  0.1], # noqa: E241, E201
+               [-0.1,  0.1,  0.1], # noqa: E241, E201
-               [-0.3,  0.0,  0.4], # noqa: E241, E201
+               [-0.3, -0.1,  0.4], # noqa: E241, E201
                [ 0.3,  0.5,  0.4], # noqa: E241, E201
-               [0.75,  0.0,  0.4], # noqa: E241, E201
+               [0.75, -0.1,  0.4], # noqa: E241, E201
-               [-0.4, -0.3,  0.9], # noqa: E241, E201
+               [ 0.2, -0.3,  0.9], # noqa: E241, E201
-               [ 0.2, -0.7,  0.9], # noqa: E241, E201
+               [ 0.3, -0.7,  0.9], # noqa: E241, E201
-               [ 0.4, -0.3,  0.9], # noqa: E241, E201
+               [ 0.6, -0.3,  0.9], # noqa: E241, E201
                [-0.4,  0.0, -1.5], # noqa: E241, E201
                [ 0.6,  0.6, -1.5], # noqa: E241, E201
                [ 0.8,  0.0, -1.5], # noqa: E241, E201
            ],
            device=device,
        )
+       # Expected faces using axes convention +Y down, + X right, +Z in
+       # Non symmetrical triangles i.e face 0 and 3 are in one bin only
        faces = torch.tensor(
            [
-               [ 1, 0, 2], # noqa: E241, E201 bin 00 and bin 01
+               [ 1, 0, 2], # noqa: E241, E201 bin 01 only
-               [ 4, 3, 5], # noqa: E241, E201 bin 00 and bin 01
+               [ 4, 3, 5], # noqa: E241, E201 all bins
-               [ 7, 6, 8], # noqa: E241, E201 bin 10 and bin 11
+               [ 7, 6, 8], # noqa: E241, E201 bin 10 only
                [10, 9, 11], # noqa: E241, E201 negative z, should not appear.
            ],
            dtype=torch.int64,
@@ -900,16 +914,19 @@ class TestRasterizeMeshes(unittest.TestCase):
        num_faces_per_mesh = meshes.num_faces_per_mesh()
        mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()

+       # Expected faces using axes convention +Y down, + X right, + Z in
        bin_faces_expected = (
            torch.ones(
                (1, 2, 2, max_faces_per_bin), dtype=torch.int32, device=device
            )
            * -1
        )
-       bin_faces_expected[0, 0, 0, 0:2] = torch.tensor([0, 1])
+       bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
+       bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([1, 2])
        bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([0, 1])
-       bin_faces_expected[0, 1, 0, 0:3] = torch.tensor([0, 1, 2])
+       bin_faces_expected[0, 1, 1, 0] = torch.tensor([1])
-       bin_faces_expected[0, 1, 1, 0:3] = torch.tensor([0, 1, 2])
+       # +Y up, +X left, +Z in
        bin_faces = _C._rasterize_meshes_coarse(
            faces_verts,
            mesh_to_face_first_idx,
@@ -919,9 +936,8 @@ class TestRasterizeMeshes(unittest.TestCase):
            bin_size,
            max_faces_per_bin,
        )
-       bin_faces_same = (
-           bin_faces.squeeze().flip(dims=[0]) == bin_faces_expected
-       ).all()
+       # Flip x and y axis of output before comparing to expected
+       bin_faces_same = (bin_faces.squeeze() == bin_faces_expected).all()
        self.assertTrue(bin_faces_same.item() == 1)

    @staticmethod
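For readers unfamiliar with the coarse step, the `bin_faces_expected` grid above lists, per bin, the indices of the faces whose screen-space bounding box overlaps that bin, padded with -1 (which is why it is initialized with `torch.ones(...) * -1`). A toy sketch of that idea, under stated assumptions (it ignores blur radius and per-bin overflow handling, and is an illustration rather than the `_C._rasterize_meshes_coarse` kernel):

```python
import torch

def toy_coarse_bin(face_bboxes, bin_bounds, max_faces_per_bin):
    # face_bboxes: (F, 4) of (xmin, ymin, xmax, ymax) per face in NDC.
    # bin_bounds:  (BY, BX, 4) of (xmin, ymin, xmax, ymax) per bin in NDC.
    BY, BX = bin_bounds.shape[:2]
    bin_faces = torch.full((BY, BX, max_faces_per_bin), -1, dtype=torch.int64)
    counts = torch.zeros((BY, BX), dtype=torch.int64)
    for f, (xmin, ymin, xmax, ymax) in enumerate(face_bboxes.tolist()):
        for by in range(BY):
            for bx in range(BX):
                bxmin, bymin, bxmax, bymax = bin_bounds[by, bx].tolist()
                separated = xmax < bxmin or xmin > bxmax or ymax < bymin or ymin > bymax
                k = int(counts[by, bx])
                if not separated and k < max_faces_per_bin:
                    bin_faces[by, bx, k] = f
                    counts[by, bx] += 1
    return bin_faces
```

Faces at negative z would be rejected before this step, which is consistent with face 3 never appearing in the expected grid above.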
|
@ -72,20 +72,25 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
|
|
||||||
        # Init rasterizer settings
        if elevated_camera:
-           R, T = look_at_view_transform(2.7, 45.0, 0.0)
+           # Elevated and rotated camera
+           R, T = look_at_view_transform(dist=2.7, elev=45.0, azim=45.0)
            postfix = "_elevated_camera"
+           # If y axis is up, the spot of light should
+           # be on the bottom left of the sphere.
        else:
+           # No elevation or azimuth rotation
            R, T = look_at_view_transform(2.7, 0.0, 0.0)
            postfix = ""
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
-       raster_settings = RasterizationSettings(
-           image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
-       )

        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
-       lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
+       lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

+       raster_settings = RasterizationSettings(
+           image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
+       )

        # Init renderer
        rasterizer = MeshRasterizer(
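For reference, a minimal usage sketch of the elevated camera construction used in this hunk, assuming the `pytorch3d.renderer` names already imported by the test file (the import path is an assumption of this sketch):

```python
import torch
from pytorch3d.renderer import OpenGLPerspectiveCameras, look_at_view_transform

device = torch.device("cuda:0")
# dist is the camera-to-object distance; elev and azim are angles in degrees.
R, T = look_at_view_transform(dist=2.7, elev=45.0, azim=45.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
```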
@@ -107,15 +112,16 @@ class TestRenderingMeshes(unittest.TestCase):

        # Load reference image
        image_ref_phong = load_rgb_image(
-           "test_simple_sphere_illuminated%s.png" % postfix
+           "test_simple_sphere_light%s.png" % postfix
        )
        self.assertTrue(torch.allclose(rgb, image_ref_phong, atol=0.05))

-       ###################################
-       # Move the light behind the object
-       ###################################
-       # Check the image is dark
-       lights.location[..., 2] = +2.0
+       ########################################################
+       # Move the light to the +z axis in world space so it is
+       # behind the sphere. Note that +Z is in, +Y up,
+       # +X left for both world and camera space.
+       ########################################################
+       lights.location[..., 2] = -2.0
        images = renderer(sphere_mesh, lights=lights)
        rgb = images[0, ..., :3].squeeze().cpu()
        if DEBUG:
@@ -133,7 +139,7 @@ class TestRenderingMeshes(unittest.TestCase):
        ######################################
        # Change the shader to a GouraudShader
        ######################################
-       lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
+       lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
        renderer = MeshRenderer(
            rasterizer=rasterizer,
            shader=HardGouraudShader(
@@ -143,7 +149,7 @@ class TestRenderingMeshes(unittest.TestCase):
        images = renderer(sphere_mesh)
        rgb = images[0, ..., :3].squeeze().cpu()
        if DEBUG:
-           filename = "DEBUG_simple_sphere_light_gourad%s.png" % postfix
+           filename = "DEBUG_simple_sphere_light_gouraud%s.png" % postfix
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                DATA_DIR / filename
            )
@@ -157,7 +163,6 @@ class TestRenderingMeshes(unittest.TestCase):
        ######################################
        # Change the shader to a HardFlatShader
        ######################################
-       lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
        renderer = MeshRenderer(
            rasterizer=rasterizer,
            shader=HardFlatShader(
@@ -217,7 +222,7 @@ class TestRenderingMeshes(unittest.TestCase):
        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
-       lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
+       lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]

        # Init renderer
        renderer = MeshRenderer(
@@ -231,7 +236,7 @@ class TestRenderingMeshes(unittest.TestCase):
        images = renderer(sphere_meshes)

        # Load ref image
-       image_ref = load_rgb_image("test_simple_sphere_illuminated.png")
+       image_ref = load_rgb_image("test_simple_sphere_light.png")

        for i in range(batch_size):
            rgb = images[i, ..., :3].squeeze().cpu()
@@ -261,7 +266,7 @@ class TestRenderingMeshes(unittest.TestCase):
        )

        # Init rasterizer settings
-       R, T = look_at_view_transform(2.7, 10, 20)
+       R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

        # Init renderer
@@ -275,7 +280,7 @@ class TestRenderingMeshes(unittest.TestCase):
        alpha = images[0, ..., 3].squeeze().cpu()
        if DEBUG:
            Image.fromarray((alpha.numpy() * 255).astype(np.uint8)).save(
-               DATA_DIR / "DEBUG_silhouette_grad.png"
+               DATA_DIR / "DEBUG_silhouette.png"
            )

        with Image.open(image_ref_filename) as raw_image_ref:
@@ -292,7 +297,8 @@ class TestRenderingMeshes(unittest.TestCase):

    def test_texture_map(self):
        """
-       Test a mesh with a texture map is loaded and rendered correctly
+       Test a mesh with a texture map is loaded and rendered correctly.
+       The pupils in the eyes of the cow should always be looking to the left.
        """
        device = torch.device("cuda:0")
        DATA_DIR = (
@@ -304,7 +310,7 @@ class TestRenderingMeshes(unittest.TestCase):
        mesh = load_objs_as_meshes([obj_filename], device=device)

        # Init rasterizer settings
-       R, T = look_at_view_transform(2.7, 10, 20)
+       R, T = look_at_view_transform(2.7, 0, 0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = RasterizationSettings(
            image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
@@ -313,7 +319,10 @@ class TestRenderingMeshes(unittest.TestCase):
        # Init shader settings
        materials = Materials(device=device)
        lights = PointLights(device=device)
-       lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
+       # Place light behind the cow in world space. The front of
+       # the cow is facing the -z direction.
+       lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]

        # Init renderer
        renderer = MeshRenderer(
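The placement above relies on the cow model facing the -Z direction in world space, so a point light at +Z sits behind it and one at -Z lights its face. A small sketch of the two placements this test uses, following the same attribute-assignment pattern as the diff (the `pytorch3d.renderer` import path is an assumption of this sketch):

```python
import torch
from pytorch3d.renderer import PointLights

device = torch.device("cuda:0")
lights = PointLights(device=device)
# Behind the cow (compared against test_texture_map_back.png).
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
# In front of the cow (used later for test_texture_map_front.png).
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
```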
@@ -328,18 +337,13 @@ class TestRenderingMeshes(unittest.TestCase):
        rgb = images[0, ..., :3].squeeze().cpu()

        # Load reference image
-       image_ref = load_rgb_image("test_texture_map.png")
+       image_ref = load_rgb_image("test_texture_map_back.png")

        if DEBUG:
            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
-               DATA_DIR / "DEBUG_texture_map.png"
+               DATA_DIR / "DEBUG_texture_map_back.png"
            )

-       # There's a calculation instability on the corner of the ear of the cow.
-       # We ignore that pixel.
-       image_ref[137, 166] = 0
-       rgb[137, 166] = 0

        self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))

        # Check grad exists
@@ -352,10 +356,31 @@ class TestRenderingMeshes(unittest.TestCase):
        images[0, ...].sum().backward()
        self.assertIsNotNone(verts.grad)

+       ##########################################
+       # Check rendering of the front of the cow
+       ##########################################

+       R, T = look_at_view_transform(2.7, 0, 180)
+       cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)

+       # Move light to the front of the cow in world space
+       lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
+       images = renderer(mesh, cameras=cameras, lights=lights)
+       rgb = images[0, ..., :3].squeeze().cpu()

+       # Load reference image
+       image_ref = load_rgb_image("test_texture_map_front.png")

+       if DEBUG:
+           Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+               DATA_DIR / "DEBUG_texture_map_front.png"
+           )

        #################################
        # Add blurring to rasterization
        #################################
+       R, T = look_at_view_transform(2.7, 0, 180)
+       cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
        raster_settings = RasterizationSettings(
            image_size=512,
@@ -366,6 +391,7 @@ class TestRenderingMeshes(unittest.TestCase):

        images = renderer(
            mesh.clone(),
+           cameras=cameras,
            raster_settings=raster_settings,
            blend_params=blend_params,
        )