diff --git a/pytorch3d/csrc/ext.cpp b/pytorch3d/csrc/ext.cpp index 0555afc9..4d3dd4e2 100644 --- a/pytorch3d/csrc/ext.cpp +++ b/pytorch3d/csrc/ext.cpp @@ -10,7 +10,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("face_areas_normals", &FaceAreasNormals); - m.def("packed_to_padded_tensor", &packed_to_padded_tensor); + m.def("packed_to_padded", &PackedToPadded); + m.def("padded_to_packed", &PaddedToPacked); m.def("nn_points_idx", &NearestNeighborIdx); m.def("gather_scatter", &gather_scatter); m.def("rasterize_points", &RasterizePoints); diff --git a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h index 0ef03cc4..28958407 100644 --- a/pytorch3d/csrc/face_areas_normals/face_areas_normals.h +++ b/pytorch3d/csrc/face_areas_normals/face_areas_normals.h @@ -21,10 +21,12 @@ std::tuple FaceAreasNormalsCpu( at::Tensor verts, at::Tensor faces); +#ifdef WITH_CUDA // Cuda implementation. std::tuple FaceAreasNormalsCuda( at::Tensor verts, at::Tensor faces); +#endif // Implementation which is exposed. std::tuple FaceAreasNormals( diff --git a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h index 51c7e72e..99f3a944 100644 --- a/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h +++ b/pytorch3d/csrc/nearest_neighbor_points/nearest_neighbor_points.h @@ -22,8 +22,10 @@ // CPU implementation. at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2); +#ifdef WITH_CUDA // Cuda implementation. at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2); +#endif // Implementation which is exposed. at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) { diff --git a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu index b3fb5f70..e4fb881e 100644 --- a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu +++ b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.cu @@ -1,11 +1,42 @@ // Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. #include +#include +// Kernel for inputs_packed of shape (F, D), where D > 1 template -__global__ void packed_to_padded_tensor_kernel( - const scalar_t* __restrict__ inputs, - const long* __restrict__ first_idxs, +__global__ void PackedToPaddedKernel( + const scalar_t* __restrict__ inputs_packed, + const int64_t* __restrict__ first_idxs, + scalar_t* __restrict__ inputs_padded, + const size_t batch_size, + const size_t max_size, + const size_t num_inputs, + const size_t D) { + // Batch elements split evenly across blocks (num blocks = batch_size) and + // values for each element split across threads in the block. Each thread adds + // the values of its respective input elements to the global inputs_padded + // tensor. + const size_t tid = threadIdx.x; + const size_t batch_idx = blockIdx.x; + + const int64_t start = first_idxs[batch_idx]; + const int64_t end = + batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs; + const int num = end - start; + for (size_t f = tid; f < num; f += blockDim.x) { + for (size_t j = 0; j < D; ++j) { + inputs_padded[batch_idx * max_size * D + f * D + j] = + inputs_packed[(start + f) * D + j]; + } + } +} + +// Kernel for inputs of shape (F, 1) +template +__global__ void PackedToPaddedKernelD1( + const scalar_t* __restrict__ inputs_packed, + const int64_t* __restrict__ first_idxs, scalar_t* __restrict__ inputs_padded, const size_t batch_size, const size_t max_size, @@ -17,36 +48,155 @@ __global__ void packed_to_padded_tensor_kernel( const size_t tid = threadIdx.x; const size_t batch_idx = blockIdx.x; - const long start = first_idxs[batch_idx]; - const long end = + const int64_t start = first_idxs[batch_idx]; + const int64_t end = batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs; - const int num_faces = end - start; - for (size_t f = tid; f < num_faces; f += blockDim.x) { - inputs_padded[batch_idx * max_size + f] = inputs[start + f]; + const int num = end - start; + for (size_t f = tid; f < num; f += blockDim.x) { + inputs_padded[batch_idx * max_size + f] = inputs_packed[start + f]; } } -at::Tensor packed_to_padded_tensor_cuda( - at::Tensor inputs, - at::Tensor first_idxs, - const long max_size) { - const auto num_inputs = inputs.size(0); - const auto batch_size = first_idxs.size(0); +// Kernel for inputs_padded of shape (B, F, D), where D > 1 +template +__global__ void PaddedToPackedKernel( + const scalar_t* __restrict__ inputs_padded, + const int64_t* __restrict__ first_idxs, + scalar_t* __restrict__ inputs_packed, + const size_t batch_size, + const size_t max_size, + const size_t num_inputs, + const size_t D) { + // Batch elements split evenly across blocks (num blocks = batch_size) and + // values for each element split across threads in the block. Each thread adds + // the values of its respective input elements to the global inputs_packed + // tensor. + const size_t tid = threadIdx.x; + const size_t batch_idx = blockIdx.x; + + const int64_t start = first_idxs[batch_idx]; + const int64_t end = + batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs; + const int num = end - start; + for (size_t f = tid; f < num; f += blockDim.x) { + for (size_t j = 0; j < D; ++j) { + inputs_packed[(start + f) * D + j] = + inputs_padded[batch_idx * max_size * D + f * D + j]; + } + } +} + +// Kernel for inputs_padded of shape (B, F, 1) +template +__global__ void PaddedToPackedKernelD1( + const scalar_t* __restrict__ inputs_padded, + const int64_t* __restrict__ first_idxs, + scalar_t* __restrict__ inputs_packed, + const size_t batch_size, + const size_t max_size, + const size_t num_inputs) { + // Batch elements split evenly across blocks (num blocks = batch_size) and + // values for each element split across threads in the block. Each thread adds + // the values of its respective input elements to the global inputs_packed + // tensor. + const size_t tid = threadIdx.x; + const size_t batch_idx = blockIdx.x; + + const int64_t start = first_idxs[batch_idx]; + const int64_t end = + batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs; + const int num = end - start; + for (size_t f = tid; f < num; f += blockDim.x) { + inputs_packed[start + f] = inputs_padded[batch_idx * max_size + f]; + } +} + +at::Tensor PackedToPaddedCuda( + const at::Tensor inputs_packed, + const at::Tensor first_idxs, + const int64_t max_size) { + const int64_t num_inputs = inputs_packed.size(0); + const int64_t batch_size = first_idxs.size(0); + + AT_ASSERTM( + inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor"); + const int64_t D = inputs_packed.size(1); at::Tensor inputs_padded = - at::zeros({batch_size, max_size}, inputs.options()); + at::zeros({batch_size, max_size, D}, inputs_packed.options()); const int threads = 512; const int blocks = batch_size; - AT_DISPATCH_FLOATING_TYPES( - inputs.type(), "packed_to_padded_tensor_kernel", ([&] { - packed_to_padded_tensor_kernel<<>>( - inputs.data_ptr(), - first_idxs.data_ptr(), - inputs_padded.data_ptr(), - batch_size, - max_size, - num_inputs); - })); + if (D == 1) { + AT_DISPATCH_FLOATING_TYPES( + inputs_packed.type(), "packed_to_padded_d1_kernel", ([&] { + PackedToPaddedKernelD1<<>>( + inputs_packed.data_ptr(), + first_idxs.data_ptr(), + inputs_padded.data_ptr(), + batch_size, + max_size, + num_inputs); + })); + } else { + AT_DISPATCH_FLOATING_TYPES( + inputs_packed.type(), "packed_to_padded_kernel", ([&] { + PackedToPaddedKernel<<>>( + inputs_packed.data_ptr(), + first_idxs.data_ptr(), + inputs_padded.data_ptr(), + batch_size, + max_size, + num_inputs, + D); + })); + } return inputs_padded; } + +at::Tensor PaddedToPackedCuda( + const at::Tensor inputs_padded, + const at::Tensor first_idxs, + const int64_t num_inputs) { + const int64_t batch_size = inputs_padded.size(0); + const int64_t max_size = inputs_padded.size(1); + + AT_ASSERTM(batch_size == first_idxs.size(0), "sizes mismatch"); + AT_ASSERTM( + inputs_padded.dim() == 3, + "inputs_padded must be a 3-dimensional tensor"); + const int64_t D = inputs_padded.size(2); + + at::Tensor inputs_packed = + at::zeros({num_inputs, D}, inputs_padded.options()); + + const int threads = 512; + const int blocks = batch_size; + + if (D == 1) { + AT_DISPATCH_FLOATING_TYPES( + inputs_padded.type(), "padded_to_packed_d1_kernel", ([&] { + PaddedToPackedKernelD1<<>>( + inputs_padded.data_ptr(), + first_idxs.data_ptr(), + inputs_packed.data_ptr(), + batch_size, + max_size, + num_inputs); + })); + } else { + AT_DISPATCH_FLOATING_TYPES( + inputs_padded.type(), "padded_to_packed_kernel", ([&] { + PaddedToPackedKernel<<>>( + inputs_padded.data_ptr(), + first_idxs.data_ptr(), + inputs_packed.data_ptr(), + batch_size, + max_size, + num_inputs, + D); + })); + } + + return inputs_packed; +} diff --git a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h index 1edbfe30..f9ef6ed1 100644 --- a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h +++ b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor.h @@ -3,42 +3,96 @@ #pragma once #include +// PackedToPadded // Converts a packed tensor into a padded tensor, restoring the batch dimension. // Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors. // // Inputs: -// inputs: FloatTensor of shape (F,), representing the packed batch tensor. -// e.g. areas for faces in a batch of meshes. +// inputs_packed: FloatTensor of shape (F, D), representing the packed batch +// tensor, e.g. areas for faces in a batch of meshes. // first_idxs: LongTensor of shape (N,) where N is the number of -// elements in the batch and `packed_first_idxs[i] = f` +// elements in the batch and `first_idxs[i] = f` // means that the inputs for batch element i begin at // `inputs[f]`. -// max_size: Max length of an element in the batch. +// max_size: Max length of an element in the batch. // Returns: -// inputs_padded: FloatTensor of shape (N, max_size) where max_size is max +// inputs_padded: FloatTensor of shape (N, max_size, D) where max_size is max // of `sizes`. The values for batch element i which start at -// `inputs[packed_first_idxs[i]]` will be copied to -// `inputs_padded[i, :]``, with zeros padding out the extra +// `inputs_packed[first_idxs[i]]` will be copied to +// `inputs_padded[i, :]`, with zeros padding out the extra // inputs. // +// PaddedToPacked +// Converts a padded tensor into a packed tensor. +// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors. +// +// Inputs: +// inputs_padded: FloatTensor of shape (N, max_size, D), representing the +// padded tensor, e.g. areas for faces in a batch of meshes. +// first_idxs: LongTensor of shape (N,) where N is the number of +// elements in the batch and `first_idxs[i] = f` +// means that the inputs for batch element i begin at +// `inputs_packed[f]`. +// num_inputs: Number of packed entries (= F) +// Returns: +// inputs_packed: FloatTensor of shape (F, D), where +// `inputs_packed[first_idx[i]:] = inputs_padded[i, :]`. +// +// + +// Cpu implementation. +at::Tensor PackedToPaddedCpu( + const at::Tensor inputs_packed, + const at::Tensor first_idxs, + const int64_t max_size); + +// Cpu implementation. +at::Tensor PaddedToPackedCpu( + const at::Tensor inputs_padded, + const at::Tensor first_idxs, + const int64_t num_inputs); + +#ifdef WITH_CUDA // Cuda implementation. -at::Tensor packed_to_padded_tensor_cuda( - at::Tensor inputs, - at::Tensor first_idxs, - const long max_size); +at::Tensor PackedToPaddedCuda( + const at::Tensor inputs_packed, + const at::Tensor first_idxs, + const int64_t max_size); + +// Cuda implementation. +at::Tensor PaddedToPackedCuda( + const at::Tensor inputs_padded, + const at::Tensor first_idxs, + const int64_t num_inputs); +#endif // Implementation which is exposed. -at::Tensor packed_to_padded_tensor( - at::Tensor inputs, - at::Tensor first_idxs, - const long max_size) { - if (inputs.type().is_cuda()) { +at::Tensor PackedToPadded( + const at::Tensor inputs_packed, + const at::Tensor first_idxs, + const int64_t max_size) { + if (inputs_packed.type().is_cuda()) { #ifdef WITH_CUDA - return packed_to_padded_tensor_cuda(inputs, first_idxs, max_size); + return PackedToPaddedCuda(inputs_packed, first_idxs, max_size); #else AT_ERROR("Not compiled with GPU support."); #endif } - AT_ERROR("Not implemented on the CPU."); + return PackedToPaddedCpu(inputs_packed, first_idxs, max_size); +} + +// Implementation which is exposed. +at::Tensor PaddedToPacked( + const at::Tensor inputs_padded, + const at::Tensor first_idxs, + const int64_t num_inputs) { + if (inputs_padded.type().is_cuda()) { +#ifdef WITH_CUDA + return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs); } diff --git a/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.cpp b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.cpp new file mode 100644 index 00000000..dd872b78 --- /dev/null +++ b/pytorch3d/csrc/packed_to_padded_tensor/packed_to_padded_tensor_cpu.cpp @@ -0,0 +1,65 @@ +// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +#include + +at::Tensor PackedToPaddedCpu( + const at::Tensor inputs_packed, + const at::Tensor first_idxs, + const int64_t max_size) { + const int64_t num_inputs = inputs_packed.size(0); + const int64_t batch_size = first_idxs.size(0); + + AT_ASSERTM( + inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor"); + const int64_t D = inputs_packed.size(1); + + torch::Tensor inputs_padded = + torch::zeros({batch_size, max_size, D}, inputs_packed.options()); + + auto inputs_packed_a = inputs_packed.accessor(); + auto first_idxs_a = first_idxs.accessor(); + auto inputs_padded_a = inputs_padded.accessor(); + + for (int b = 0; b < batch_size; ++b) { + const int64_t start = first_idxs_a[b]; + const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs; + const int64_t num = end - start; + for (int i = 0; i < num; ++i) { + for (int j = 0; j < D; ++j) { + inputs_padded_a[b][i][j] = inputs_packed_a[start + i][j]; + } + } + } + return inputs_padded; +} + +at::Tensor PaddedToPackedCpu( + const at::Tensor inputs_padded, + const at::Tensor first_idxs, + const int64_t num_inputs) { + const int64_t batch_size = inputs_padded.size(0); + const int64_t max_size = inputs_padded.size(1); + + AT_ASSERTM( + inputs_padded.dim() == 3, "inputs_padded must be a 3-dimensional tensor"); + const int64_t D = inputs_padded.size(2); + + torch::Tensor inputs_packed = + torch::zeros({num_inputs, D}, inputs_padded.options()); + + auto inputs_padded_a = inputs_padded.accessor(); + auto first_idxs_a = first_idxs.accessor(); + auto inputs_packed_a = inputs_packed.accessor(); + + for (int b = 0; b < batch_size; ++b) { + const int64_t start = first_idxs_a[b]; + const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs; + const int64_t num = end - start; + for (int i = 0; i < num; ++i) { + for (int j = 0; j < D; ++j) { + inputs_packed_a[start + i][j] = inputs_padded_a[b][i][j]; + } + } + } + return inputs_packed; +} diff --git a/pytorch3d/ops/__init__.py b/pytorch3d/ops/__init__.py index 9cd3b4dc..0a0fe2f5 100644 --- a/pytorch3d/ops/__init__.py +++ b/pytorch3d/ops/__init__.py @@ -4,6 +4,7 @@ from .cubify import cubify from .graph_conv import GraphConv from .nearest_neighbor_points import nn_points_idx +from .packed_to_padded import packed_to_padded, padded_to_packed from .sample_points_from_meshes import sample_points_from_meshes from .subdivide_meshes import SubdivideMeshes from .vert_align import vert_align diff --git a/pytorch3d/ops/packed_to_padded.py b/pytorch3d/ops/packed_to_padded.py new file mode 100644 index 00000000..c64894b7 --- /dev/null +++ b/pytorch3d/ops/packed_to_padded.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +from pytorch3d import _C + + +class _PackedToPadded(Function): + """ + Torch autograd Function wrapper for packed_to_padded C++/CUDA implementations. + """ + + @staticmethod + def forward(ctx, inputs, first_idxs, max_size): + """ + Args: + ctx: Context object used to calculate gradients. + inputs: FloatTensor of shape (F, D), representing the packed batch tensor. + e.g. areas for faces in a batch of meshes. + first_idxs: LongTensor of shape (N,) where N is the number of + elements in the batch and `first_idxs[i] = f` + means that the inputs for batch element i begin at `inputs[f]`. + max_size: Max length of an element in the batch. + + Returns: + inputs_padded: FloatTensor of shape (N, max_size, D) where max_size is max + of `sizes`. The values for batch element i which start at + `inputs[first_idxs[i]]` will be copied to `inputs_padded[i, :]`, + with zeros padding out the extra inputs. + """ + if not (inputs.dim() == 2): + raise ValueError("input can only be 2-dimensional.") + if not (first_idxs.dim() == 1): + raise ValueError("first_idxs can only be 1-dimensional.") + if not (inputs.dtype == torch.float32): + raise ValueError("input has to be of type torch.float32.") + if not (first_idxs.dtype == torch.int64): + raise ValueError("first_idxs has to be of type torch.int64.") + if not isinstance(max_size, int): + raise ValueError("max_size has to be int.") + + ctx.save_for_backward(first_idxs) + ctx.num_inputs = int(inputs.shape[0]) + inputs, first_idxs = inputs.contiguous(), first_idxs.contiguous() + inputs_padded = _C.packed_to_padded(inputs, first_idxs, max_size) + return inputs_padded + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + first_idxs = ctx.saved_tensors[0] + num_inputs = ctx.num_inputs + grad_input = _C.padded_to_packed(grad_output, first_idxs, num_inputs) + return grad_input, None, None + + +def packed_to_padded(inputs, first_idxs, max_size): + """ + Torch wrapper that handles allowed input shapes. See description below. + + Args: + inputs: FloatTensor of shape (F,) or (F, D), representing the packed batch tensor. + e.g. areas for faces in a batch of meshes. + first_idxs: LongTensor of shape (N,) where N is the number of + elements in the batch and `first_idxs[i] = f` + means that the inputs for batch element i begin at `inputs[f]`. + max_size: Max length of an element in the batch. + + Returns: + inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, D) where max_size is + max of `sizes`. The values for batch element i which start at + `inputs[first_idxs[i]]` will be copied to `inputs_padded[i, :]`, + with zeros padding out the extra inputs. + + To handle the allowed input shapes, we convert the inputs tensor of shape (F,) to (F, 1). + We reshape the output back to (N, max_size) from (N, max_size, 1). + """ + # if inputs is of shape (F,), reshape into (F, 1) + flat = False + if inputs.dim() == 1: + flat = True + inputs = inputs.unsqueeze(1) + inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size) + # if flat is True, reshape output to (N, max_size) from (N, max_size, 1) + if flat: + inputs_padded = inputs_padded.squeeze(2) + return inputs_padded + + +class _PaddedToPacked(Function): + """ + Torch autograd Function wrapper for padded_to_packed C++/CUDA implementations. + """ + + @staticmethod + def forward(ctx, inputs, first_idxs, num_inputs): + """ + Args: + ctx: Context object used to calculate gradients. + inputs: FloatTensor of shape (N, max_size, D), representing the padded tensor. + e.g. areas for faces in a batch of meshes. + first_idxs: LongTensor of shape (N,) where N is the number of + elements in the batch and `first_idxs[i] = f` + means that the inputs for batch element i begin at `inputs_packed[f]`. + num_inputs: Number of packed entries (= F) + + Returns: + inputs_packed: FloatTensor of shape (F, D) where + `inputs_packed[first_idx[i]:] = inputs[i, :]`. + """ + if not (inputs.dim() == 3): + raise ValueError("input can only be 3-dimensional.") + if not (first_idxs.dim() == 1): + raise ValueError("first_idxs can only be 1-dimensional.") + if not (inputs.dtype == torch.float32): + raise ValueError("input has to be of type torch.float32.") + if not (first_idxs.dtype == torch.int64): + raise ValueError("first_idxs has to be of type torch.int64.") + if not isinstance(num_inputs, int): + raise ValueError("max_size has to be int.") + + ctx.save_for_backward(first_idxs) + ctx.max_size = inputs.shape[1] + inputs, first_idxs = inputs.contiguous(), first_idxs.contiguous() + inputs_packed = _C.padded_to_packed(inputs, first_idxs, num_inputs) + return inputs_packed + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + first_idxs = ctx.saved_tensors[0] + max_size = ctx.max_size + grad_input = _C.packed_to_padded(grad_output, first_idxs, max_size) + return grad_input, None, None + + +def padded_to_packed(inputs, first_idxs, num_inputs): + """ + Torch wrapper that handles allowed input shapes. See description below. + + Args: + inputs: FloatTensor of shape (N, max_size) or (N, max_size, D), representing the + padded tensor. e.g. areas for faces in a batch of meshes. + first_idxs: LongTensor of shape (N,) where N is the number of + elements in the batch and `first_idxs[i] = f` + means that the inputs for batch element i begin at `inputs_packed[f]`. + num_inputs: Number of packed entries (= F) + + Returns: + inputs_packed: FloatTensor of shape (F,) or (F, D) where + `inputs_packed[first_idx[i]:] = inputs[i, :]`. + + To handle the allowed input shapes, we convert the inputs tensor of shape (N, max_size) + to (N, max_size, 1). We reshape the output back to (F,) from (F, 1). + """ + # if inputs is of shape (N, max_size), reshape into (N, max_size, 1)) + flat = False + if inputs.dim() == 2: + flat = True + inputs = inputs.unsqueeze(2) + inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs) + # if flat is True, reshape output to (F,) from (F, 1) + if flat: + inputs_packed = inputs_packed.squeeze(1) + return inputs_packed diff --git a/pytorch3d/ops/sample_points_from_meshes.py b/pytorch3d/ops/sample_points_from_meshes.py index 0fab7830..abe0f25c 100644 --- a/pytorch3d/ops/sample_points_from_meshes.py +++ b/pytorch3d/ops/sample_points_from_meshes.py @@ -12,6 +12,8 @@ import torch from pytorch3d import _C +from .packed_to_padded import packed_to_padded + def sample_points_from_meshes( meshes, num_samples: int = 10000, return_normals: bool = False @@ -55,7 +57,7 @@ def sample_points_from_meshes( verts, faces ) # Face areas can be zero. max_faces = meshes.num_faces_per_mesh().max().item() - areas_padded = _C.packed_to_padded_tensor( + areas_padded = packed_to_padded( areas, mesh_to_face[meshes.valid], max_faces ) # (N, F) diff --git a/tests/bm_packed_to_padded.py b/tests/bm_packed_to_padded.py new file mode 100644 index 00000000..a83af204 --- /dev/null +++ b/tests/bm_packed_to_padded.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + + +from itertools import product +import torch +from fvcore.common.benchmark import benchmark + +from test_packed_to_padded import TestPackedToPadded + + +def bm_packed_to_padded() -> None: + kwargs_list = [] + backend = ["cpu"] + if torch.cuda.is_available(): + backend.append("cuda:0") + + num_meshes = [2, 10, 32] + num_verts = [100, 1000] + num_faces = [300, 3000] + num_ds = [0, 1, 16] + + test_cases = product(num_meshes, num_verts, num_faces, num_ds, backend) + for case in test_cases: + n, v, f, d, b = case + kwargs_list.append( + { + "num_meshes": n, + "num_verts": v, + "num_faces": f, + "num_d": d, + "device": b, + } + ) + benchmark( + TestPackedToPadded.packed_to_padded_with_init, + "PACKED_TO_PADDED", + kwargs_list, + warmup_iters=1, + ) + + benchmark( + TestPackedToPadded.packed_to_padded_with_init_torch, + "PACKED_TO_PADDED_TORCH", + kwargs_list, + warmup_iters=1, + ) diff --git a/tests/bm_sample_points_from_meshes.py b/tests/bm_sample_points_from_meshes.py index 4939717e..859be1d8 100644 --- a/tests/bm_sample_points_from_meshes.py +++ b/tests/bm_sample_points_from_meshes.py @@ -10,56 +10,30 @@ from test_sample_points_from_meshes import TestSamplePoints def bm_sample_points() -> None: - if torch.cuda.is_available(): - device = "cuda:0" - kwargs_list = [] - num_meshes = [2, 10, 32] - num_verts = [100, 1000] - num_faces = [300, 3000] - num_samples = [5000, 10000] - test_cases = product(num_meshes, num_verts, num_faces, num_samples) - for case in test_cases: - n, v, f, s = case - kwargs_list.append( - { - "num_meshes": n, - "num_verts": v, - "num_faces": f, - "num_samples": s, - "device": device, - } - ) - benchmark( - TestSamplePoints.sample_points_with_init, - "SAMPLE_MESH", - kwargs_list, - warmup_iters=1, - ) + backend = ["cpu"] + if torch.cuda.is_available(): + backend.append("cuda:0") kwargs_list = [] - backend_cuda = ["False"] - if torch.cuda.is_available(): - backend_cuda.append("True") - num_meshes = [2, 10, 32] num_verts = [100, 1000] num_faces = [300, 3000] - - test_cases = product(num_meshes, num_verts, num_faces, backend_cuda) + num_samples = [5000, 10000] + test_cases = product(num_meshes, num_verts, num_faces, num_samples, backend) for case in test_cases: - n, v, f, c = case + n, v, f, s, b = case kwargs_list.append( - {"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c} + { + "num_meshes": n, + "num_verts": v, + "num_faces": f, + "num_samples": s, + "device": b, + } ) benchmark( - TestSamplePoints.face_areas_with_init, - "FACE_AREAS", - kwargs_list, - warmup_iters=1, - ) - benchmark( - TestSamplePoints.packed_to_padded_with_init, - "PACKED_TO_PADDED", + TestSamplePoints.sample_points_with_init, + "SAMPLE_MESH", kwargs_list, warmup_iters=1, ) diff --git a/tests/test_packed_to_padded.py b/tests/test_packed_to_padded.py new file mode 100644 index 00000000..a5260e10 --- /dev/null +++ b/tests/test_packed_to_padded.py @@ -0,0 +1,296 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest +import torch + +from pytorch3d.ops import packed_to_padded, padded_to_packed +from pytorch3d.structures.meshes import Meshes + +from common_testing import TestCaseMixin + + +class TestPackedToPadded(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + super().setUp() + torch.manual_seed(1) + + @staticmethod + def init_meshes( + num_meshes: int = 10, + num_verts: int = 1000, + num_faces: int = 3000, + device: str = "cpu", + ): + device = torch.device(device) + verts_list = [] + faces_list = [] + for _ in range(num_meshes): + verts = torch.rand( + (num_verts, 3), dtype=torch.float32, device=device + ) + faces = torch.randint( + num_verts, size=(num_faces, 3), dtype=torch.int64, device=device + ) + verts_list.append(verts) + faces_list.append(faces) + meshes = Meshes(verts_list, faces_list) + + return meshes + + @staticmethod + def packed_to_padded_python(inputs, first_idxs, max_size, device): + """ + PyTorch implementation of packed_to_padded function. + """ + num_meshes = first_idxs.size(0) + D = inputs.shape[1] if inputs.dim() == 2 else 0 + if D == 0: + inputs_padded = torch.zeros((num_meshes, max_size), device=device) + else: + inputs_padded = torch.zeros( + (num_meshes, max_size, D), device=device + ) + for m in range(num_meshes): + s = first_idxs[m] + if m == num_meshes - 1: + f = inputs.shape[0] + else: + f = first_idxs[m + 1] + inputs_padded[m, :f] = inputs[s:f] + + return inputs_padded + + @staticmethod + def padded_to_packed_python(inputs, first_idxs, num_inputs, device): + """ + PyTorch implementation of padded_to_packed function. + """ + num_meshes = inputs.size(0) + D = inputs.shape[2] if inputs.dim() == 3 else 0 + if D == 0: + inputs_packed = torch.zeros((num_inputs,), device=device) + else: + inputs_packed = torch.zeros((num_inputs, D), device=device) + for m in range(num_meshes): + s = first_idxs[m] + if m == num_meshes - 1: + f = num_inputs + else: + f = first_idxs[m + 1] + inputs_packed[s:f] = inputs[m, :f] + + return inputs_packed + + def _test_packed_to_padded_helper(self, D, device): + """ + Check the results from packed_to_padded and PyTorch implementations + are the same. + """ + meshes = self.init_meshes(16, 100, 300, device=device) + faces = meshes.faces_packed() + mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() + max_faces = meshes.num_faces_per_mesh().max().item() + + if D == 0: + values = torch.rand( + (faces.shape[0],), device=device, requires_grad=True + ) + else: + values = torch.rand( + (faces.shape[0], D), device=device, requires_grad=True + ) + values_torch = values.detach().clone() + values_torch.requires_grad = True + values_padded = packed_to_padded( + values, mesh_to_faces_packed_first_idx, max_faces + ) + values_padded_torch = TestPackedToPadded.packed_to_padded_python( + values_torch, mesh_to_faces_packed_first_idx, max_faces, device + ) + # check forward + self.assertClose(values_padded, values_padded_torch) + + # check backward + if D == 0: + grad_inputs = torch.rand((len(meshes), max_faces), device=device) + else: + grad_inputs = torch.rand((len(meshes), max_faces, D), device=device) + values_padded.backward(grad_inputs) + grad_outputs = values.grad + values_padded_torch.backward(grad_inputs) + grad_outputs_torch1 = values_torch.grad + grad_outputs_torch2 = TestPackedToPadded.padded_to_packed_python( + grad_inputs, + mesh_to_faces_packed_first_idx, + values.size(0), + device=device, + ) + self.assertClose(grad_outputs, grad_outputs_torch1) + self.assertClose(grad_outputs, grad_outputs_torch2) + + def test_packed_to_padded_flat_cpu(self): + self._test_packed_to_padded_helper(0, "cpu") + + def test_packed_to_padded_D1_cpu(self): + self._test_packed_to_padded_helper(1, "cpu") + + def test_packed_to_padded_D16_cpu(self): + self._test_packed_to_padded_helper(16, "cpu") + + def test_packed_to_padded_flat_cuda(self): + self._test_packed_to_padded_helper(0, "cuda:0") + + def test_packed_to_padded_D1_cuda(self): + self._test_packed_to_padded_helper(1, "cuda:0") + + def test_packed_to_padded_D16_cuda(self): + self._test_packed_to_padded_helper(16, "cuda:0") + + def _test_padded_to_packed_helper(self, D, device): + """ + Check the results from packed_to_padded and PyTorch implementations + are the same. + """ + meshes = self.init_meshes(16, 100, 300, device=device) + mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() + num_faces_per_mesh = meshes.num_faces_per_mesh() + max_faces = num_faces_per_mesh.max().item() + if D == 0: + values = torch.rand((len(meshes), max_faces), device=device) + else: + values = torch.rand((len(meshes), max_faces, D), device=device) + for i, num in enumerate(num_faces_per_mesh): + values[i, num:] = 0 + values.requires_grad = True + values_torch = values.detach().clone() + values_torch.requires_grad = True + values_packed = padded_to_packed( + values, + mesh_to_faces_packed_first_idx, + num_faces_per_mesh.sum().item(), + ) + values_packed_torch = TestPackedToPadded.padded_to_packed_python( + values_torch, + mesh_to_faces_packed_first_idx, + num_faces_per_mesh.sum().item(), + device, + ) + # check forward + self.assertClose(values_packed, values_packed_torch) + + # check backward + if D == 0: + grad_inputs = torch.rand( + (num_faces_per_mesh.sum().item()), device=device + ) + else: + grad_inputs = torch.rand( + (num_faces_per_mesh.sum().item(), D), device=device + ) + values_packed.backward(grad_inputs) + grad_outputs = values.grad + values_packed_torch.backward(grad_inputs) + grad_outputs_torch1 = values_torch.grad + grad_outputs_torch2 = TestPackedToPadded.packed_to_padded_python( + grad_inputs, + mesh_to_faces_packed_first_idx, + values.size(1), + device=device, + ) + self.assertClose(grad_outputs, grad_outputs_torch1) + self.assertClose(grad_outputs, grad_outputs_torch2) + + def test_padded_to_packed_flat_cpu(self): + self._test_padded_to_packed_helper(0, "cpu") + + def test_padded_to_packed_D1_cpu(self): + self._test_padded_to_packed_helper(1, "cpu") + + def test_padded_to_packed_D16_cpu(self): + self._test_padded_to_packed_helper(16, "cpu") + + def test_padded_to_packed_flat_cuda(self): + self._test_padded_to_packed_helper(0, "cuda:0") + + def test_padded_to_packed_D1_cuda(self): + self._test_padded_to_packed_helper(1, "cuda:0") + + def test_padded_to_packed_D16_cuda(self): + self._test_padded_to_packed_helper(16, "cuda:0") + + def test_invalid_inputs_shapes(self, device="cuda:0"): + with self.assertRaisesRegex( + ValueError, "input can only be 2-dimensional." + ): + values = torch.rand((100, 50, 2), device=device) + first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device) + packed_to_padded(values, first_idxs, 100) + + with self.assertRaisesRegex( + ValueError, "input can only be 3-dimensional." + ): + values = torch.rand((100,), device=device) + first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device) + padded_to_packed(values, first_idxs, 20) + + with self.assertRaisesRegex( + ValueError, "input can only be 3-dimensional." + ): + values = torch.rand((100, 50, 2, 2), device=device) + first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device) + padded_to_packed(values, first_idxs, 20) + + @staticmethod + def packed_to_padded_with_init( + num_meshes: int, + num_verts: int, + num_faces: int, + num_d: int, + device: str = "cpu", + ): + meshes = TestPackedToPadded.init_meshes( + num_meshes, num_verts, num_faces, device + ) + faces = meshes.faces_packed() + mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() + max_faces = meshes.num_faces_per_mesh().max().item() + if num_d == 0: + values = torch.rand((faces.shape[0],), device=meshes.device) + else: + values = torch.rand((faces.shape[0], num_d), device=meshes.device) + torch.cuda.synchronize() + + def out(): + packed_to_padded(values, mesh_to_faces_packed_first_idx, max_faces) + torch.cuda.synchronize() + + return out + + @staticmethod + def packed_to_padded_with_init_torch( + num_meshes: int, + num_verts: int, + num_faces: int, + num_d: int, + device: str = "cpu", + ): + meshes = TestPackedToPadded.init_meshes( + num_meshes, num_verts, num_faces, device + ) + faces = meshes.faces_packed() + mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() + max_faces = meshes.num_faces_per_mesh().max().item() + if num_d == 0: + values = torch.rand((faces.shape[0],), device=meshes.device) + else: + values = torch.rand((faces.shape[0], num_d), device=meshes.device) + torch.cuda.synchronize() + + def out(): + TestPackedToPadded.packed_to_padded_python( + values, mesh_to_faces_packed_first_idx, max_faces, device + ) + torch.cuda.synchronize() + + return out diff --git a/tests/test_sample_points_from_meshes.py b/tests/test_sample_points_from_meshes.py index d210731b..90758124 100644 --- a/tests/test_sample_points_from_meshes.py +++ b/tests/test_sample_points_from_meshes.py @@ -294,48 +294,6 @@ class TestSamplePoints(unittest.TestCase): return False return True - @staticmethod - def packed_to_padded_tensor(inputs, first_idxs, max_size): - """ - PyTorch implementation of cuda packed_to_padded_tensor function. - """ - num_meshes = first_idxs.size(0) - inputs_padded = torch.zeros((num_meshes, max_size)) - for m in range(num_meshes): - s = first_idxs[m] - if m == num_meshes - 1: - f = inputs.size(0) - else: - f = first_idxs[m + 1] - inputs_padded[m, :f] = inputs[s:f] - - return inputs_padded - - def test_packed_to_padded_tensor(self): - """ - Check the results from packed_to_padded cuda and PyTorch implementions - are the same. - """ - meshes = self.init_meshes(1, 3, 5, device="cuda:0") - verts = meshes.verts_packed() - faces = meshes.faces_packed() - mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() - max_faces = meshes.num_faces_per_mesh().max().item() - - areas, _ = _C.face_areas_normals(verts, faces) - areas_padded = _C.packed_to_padded_tensor( - areas, mesh_to_faces_packed_first_idx, max_faces - ).cpu() - areas_padded_cpu = TestSamplePoints.packed_to_padded_tensor( - areas, mesh_to_faces_packed_first_idx, max_faces - ) - self.assertTrue(torch.allclose(areas_padded, areas_padded_cpu)) - with self.assertRaises(Exception) as err: - _C.packed_to_padded_tensor( - areas.cpu(), mesh_to_faces_packed_first_idx, max_faces - ) - self.assertTrue("Not implemented on the CPU" in str(err.exception)) - @staticmethod def sample_points_with_init( num_meshes: int, @@ -344,7 +302,6 @@ class TestSamplePoints(unittest.TestCase): num_samples: int, device: str = "cpu", ): - device = torch.device(device) verts_list = [] faces_list = [] for _ in range(num_meshes): @@ -366,32 +323,3 @@ class TestSamplePoints(unittest.TestCase): torch.cuda.synchronize() return sample_points - - @staticmethod - def packed_to_padded_with_init( - num_meshes: int, num_verts: int, num_faces: int, cuda: str = True - ): - device = "cuda" if cuda else "cpu" - meshes = TestSamplePoints.init_meshes( - num_meshes, num_verts, num_faces, device - ) - verts = meshes.verts_packed() - faces = meshes.faces_packed() - mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx() - max_faces = meshes.num_faces_per_mesh().max().item() - - areas, _ = _C.face_areas_normals(verts, faces) - torch.cuda.synchronize() - - def packed_to_padded(): - if cuda: - _C.packed_to_padded_tensor( - areas, mesh_to_faces_packed_first_idx, max_faces - ) - else: - TestSamplePoints.packed_to_padded_tensor( - areas, mesh_to_faces_packed_first_idx, max_faces - ) - torch.cuda.synchronize() - - return packed_to_padded