mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
cpp support for packed to padded
Summary: Cpu implementation for packed to padded and added gradients ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- PACKED_TO_PADDED_2_100_300_1_cpu 138 221 3625 PACKED_TO_PADDED_2_100_300_1_cuda:0 184 261 2716 PACKED_TO_PADDED_2_100_300_16_cpu 555 726 901 PACKED_TO_PADDED_2_100_300_16_cuda:0 179 260 2794 PACKED_TO_PADDED_2_100_3000_1_cpu 396 519 1262 PACKED_TO_PADDED_2_100_3000_1_cuda:0 181 274 2764 PACKED_TO_PADDED_2_100_3000_16_cpu 4517 5003 111 PACKED_TO_PADDED_2_100_3000_16_cuda:0 224 397 2235 PACKED_TO_PADDED_2_1000_300_1_cpu 138 212 3616 PACKED_TO_PADDED_2_1000_300_1_cuda:0 180 282 2775 PACKED_TO_PADDED_2_1000_300_16_cpu 565 711 885 PACKED_TO_PADDED_2_1000_300_16_cuda:0 179 264 2797 PACKED_TO_PADDED_2_1000_3000_1_cpu 389 494 1287 PACKED_TO_PADDED_2_1000_3000_1_cuda:0 180 271 2777 PACKED_TO_PADDED_2_1000_3000_16_cpu 4522 5170 111 PACKED_TO_PADDED_2_1000_3000_16_cuda:0 216 286 2313 PACKED_TO_PADDED_10_100_300_1_cpu 251 345 1995 PACKED_TO_PADDED_10_100_300_1_cuda:0 178 262 2806 PACKED_TO_PADDED_10_100_300_16_cpu 2354 2750 213 PACKED_TO_PADDED_10_100_300_16_cuda:0 178 291 2814 PACKED_TO_PADDED_10_100_3000_1_cpu 1519 1786 330 PACKED_TO_PADDED_10_100_3000_1_cuda:0 179 237 2791 PACKED_TO_PADDED_10_100_3000_16_cpu 24705 25879 21 PACKED_TO_PADDED_10_100_3000_16_cuda:0 228 316 2191 PACKED_TO_PADDED_10_1000_300_1_cpu 261 432 1919 PACKED_TO_PADDED_10_1000_300_1_cuda:0 181 261 2756 PACKED_TO_PADDED_10_1000_300_16_cpu 2349 2770 213 PACKED_TO_PADDED_10_1000_300_16_cuda:0 180 256 2782 PACKED_TO_PADDED_10_1000_3000_1_cpu 1613 1929 310 PACKED_TO_PADDED_10_1000_3000_1_cuda:0 183 253 2739 PACKED_TO_PADDED_10_1000_3000_16_cpu 22041 23653 23 PACKED_TO_PADDED_10_1000_3000_16_cuda:0 220 343 2270 PACKED_TO_PADDED_32_100_300_1_cpu 555 750 901 PACKED_TO_PADDED_32_100_300_1_cuda:0 188 282 2661 PACKED_TO_PADDED_32_100_300_16_cpu 7550 8131 67 PACKED_TO_PADDED_32_100_300_16_cuda:0 181 272 2770 PACKED_TO_PADDED_32_100_3000_1_cpu 4574 6327 110 PACKED_TO_PADDED_32_100_3000_1_cuda:0 173 254 2884 PACKED_TO_PADDED_32_100_3000_16_cpu 70366 72563 8 PACKED_TO_PADDED_32_100_3000_16_cuda:0 349 654 1433 PACKED_TO_PADDED_32_1000_300_1_cpu 612 728 818 PACKED_TO_PADDED_32_1000_300_1_cuda:0 189 295 2647 PACKED_TO_PADDED_32_1000_300_16_cpu 7699 8254 65 PACKED_TO_PADDED_32_1000_300_16_cuda:0 189 311 2646 PACKED_TO_PADDED_32_1000_3000_1_cpu 5105 5261 98 PACKED_TO_PADDED_32_1000_3000_1_cuda:0 191 260 2625 PACKED_TO_PADDED_32_1000_3000_16_cpu 87073 92708 6 PACKED_TO_PADDED_32_1000_3000_16_cuda:0 344 425 1455 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- PACKED_TO_PADDED_TORCH_2_100_300_1_cpu 492 627 1016 PACKED_TO_PADDED_TORCH_2_100_300_1_cuda:0 768 975 652 PACKED_TO_PADDED_TORCH_2_100_300_16_cpu 659 804 760 PACKED_TO_PADDED_TORCH_2_100_300_16_cuda:0 781 918 641 PACKED_TO_PADDED_TORCH_2_100_3000_1_cpu 624 734 802 PACKED_TO_PADDED_TORCH_2_100_3000_1_cuda:0 778 929 643 PACKED_TO_PADDED_TORCH_2_100_3000_16_cpu 2609 2850 192 PACKED_TO_PADDED_TORCH_2_100_3000_16_cuda:0 758 901 660 PACKED_TO_PADDED_TORCH_2_1000_300_1_cpu 467 612 1072 PACKED_TO_PADDED_TORCH_2_1000_300_1_cuda:0 772 905 648 PACKED_TO_PADDED_TORCH_2_1000_300_16_cpu 689 839 726 PACKED_TO_PADDED_TORCH_2_1000_300_16_cuda:0 789 1143 635 PACKED_TO_PADDED_TORCH_2_1000_3000_1_cpu 629 735 795 PACKED_TO_PADDED_TORCH_2_1000_3000_1_cuda:0 812 916 616 PACKED_TO_PADDED_TORCH_2_1000_3000_16_cpu 2716 3117 185 PACKED_TO_PADDED_TORCH_2_1000_3000_16_cuda:0 844 1288 593 PACKED_TO_PADDED_TORCH_10_100_300_1_cpu 2387 2557 210 PACKED_TO_PADDED_TORCH_10_100_300_1_cuda:0 4112 4993 122 PACKED_TO_PADDED_TORCH_10_100_300_16_cpu 3385 4254 148 PACKED_TO_PADDED_TORCH_10_100_300_16_cuda:0 3959 4902 127 PACKED_TO_PADDED_TORCH_10_100_3000_1_cpu 2918 3105 172 PACKED_TO_PADDED_TORCH_10_100_3000_1_cuda:0 4054 4450 124 PACKED_TO_PADDED_TORCH_10_100_3000_16_cpu 12748 13623 40 PACKED_TO_PADDED_TORCH_10_100_3000_16_cuda:0 4023 4395 125 PACKED_TO_PADDED_TORCH_10_1000_300_1_cpu 2258 2492 222 PACKED_TO_PADDED_TORCH_10_1000_300_1_cuda:0 3997 4312 126 PACKED_TO_PADDED_TORCH_10_1000_300_16_cpu 3404 3597 147 PACKED_TO_PADDED_TORCH_10_1000_300_16_cuda:0 3877 4227 129 PACKED_TO_PADDED_TORCH_10_1000_3000_1_cpu 2789 3054 180 PACKED_TO_PADDED_TORCH_10_1000_3000_1_cuda:0 3821 4402 131 PACKED_TO_PADDED_TORCH_10_1000_3000_16_cpu 11967 12963 42 PACKED_TO_PADDED_TORCH_10_1000_3000_16_cuda:0 3729 4290 135 PACKED_TO_PADDED_TORCH_32_100_300_1_cpu 6933 8152 73 PACKED_TO_PADDED_TORCH_32_100_300_1_cuda:0 11856 12287 43 PACKED_TO_PADDED_TORCH_32_100_300_16_cpu 9895 11205 51 PACKED_TO_PADDED_TORCH_32_100_300_16_cuda:0 12354 13596 41 PACKED_TO_PADDED_TORCH_32_100_3000_1_cpu 9516 10128 53 PACKED_TO_PADDED_TORCH_32_100_3000_1_cuda:0 12917 13597 39 PACKED_TO_PADDED_TORCH_32_100_3000_16_cpu 41209 43783 13 PACKED_TO_PADDED_TORCH_32_100_3000_16_cuda:0 12210 13288 41 PACKED_TO_PADDED_TORCH_32_1000_300_1_cpu 7179 7689 70 PACKED_TO_PADDED_TORCH_32_1000_300_1_cuda:0 11896 12381 43 PACKED_TO_PADDED_TORCH_32_1000_300_16_cpu 10127 15494 50 PACKED_TO_PADDED_TORCH_32_1000_300_16_cuda:0 12034 12817 42 PACKED_TO_PADDED_TORCH_32_1000_3000_1_cpu 8743 10251 58 PACKED_TO_PADDED_TORCH_32_1000_3000_1_cuda:0 12023 12908 42 PACKED_TO_PADDED_TORCH_32_1000_3000_16_cpu 39071 41777 13 PACKED_TO_PADDED_TORCH_32_1000_3000_16_cuda:0 11999 13690 42 -------------------------------------------------------------------------------- ``` Reviewed By: bottler, nikhilaravi, jcjohnson Differential Revision: D19870575 fbshipit-source-id: 23a2477b73373c411899633386c87ab034c3702a
This commit is contained in:
parent
8301163d24
commit
60f3c4e7d2
@ -10,7 +10,8 @@
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("face_areas_normals", &FaceAreasNormals);
|
||||
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
|
||||
m.def("packed_to_padded", &PackedToPadded);
|
||||
m.def("padded_to_packed", &PaddedToPacked);
|
||||
m.def("nn_points_idx", &NearestNeighborIdx);
|
||||
m.def("gather_scatter", &gather_scatter);
|
||||
m.def("rasterize_points", &RasterizePoints);
|
||||
|
@ -21,10 +21,12 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
// Cuda implementation.
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
|
||||
at::Tensor verts,
|
||||
at::Tensor faces);
|
||||
#endif
|
||||
|
||||
// Implementation which is exposed.
|
||||
std::tuple<at::Tensor, at::Tensor> FaceAreasNormals(
|
||||
|
@ -22,8 +22,10 @@
|
||||
// CPU implementation.
|
||||
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
// Cuda implementation.
|
||||
at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2);
|
||||
#endif
|
||||
|
||||
// Implementation which is exposed.
|
||||
at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) {
|
||||
|
@ -1,11 +1,42 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
// Kernel for inputs_packed of shape (F, D), where D > 1
|
||||
template <typename scalar_t>
|
||||
__global__ void packed_to_padded_tensor_kernel(
|
||||
const scalar_t* __restrict__ inputs,
|
||||
const long* __restrict__ first_idxs,
|
||||
__global__ void PackedToPaddedKernel(
|
||||
const scalar_t* __restrict__ inputs_packed,
|
||||
const int64_t* __restrict__ first_idxs,
|
||||
scalar_t* __restrict__ inputs_padded,
|
||||
const size_t batch_size,
|
||||
const size_t max_size,
|
||||
const size_t num_inputs,
|
||||
const size_t D) {
|
||||
// Batch elements split evenly across blocks (num blocks = batch_size) and
|
||||
// values for each element split across threads in the block. Each thread adds
|
||||
// the values of its respective input elements to the global inputs_padded
|
||||
// tensor.
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t batch_idx = blockIdx.x;
|
||||
|
||||
const int64_t start = first_idxs[batch_idx];
|
||||
const int64_t end =
|
||||
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
|
||||
const int num = end - start;
|
||||
for (size_t f = tid; f < num; f += blockDim.x) {
|
||||
for (size_t j = 0; j < D; ++j) {
|
||||
inputs_padded[batch_idx * max_size * D + f * D + j] =
|
||||
inputs_packed[(start + f) * D + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel for inputs of shape (F, 1)
|
||||
template <typename scalar_t>
|
||||
__global__ void PackedToPaddedKernelD1(
|
||||
const scalar_t* __restrict__ inputs_packed,
|
||||
const int64_t* __restrict__ first_idxs,
|
||||
scalar_t* __restrict__ inputs_padded,
|
||||
const size_t batch_size,
|
||||
const size_t max_size,
|
||||
@ -17,36 +48,155 @@ __global__ void packed_to_padded_tensor_kernel(
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t batch_idx = blockIdx.x;
|
||||
|
||||
const long start = first_idxs[batch_idx];
|
||||
const long end =
|
||||
const int64_t start = first_idxs[batch_idx];
|
||||
const int64_t end =
|
||||
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
|
||||
const int num_faces = end - start;
|
||||
for (size_t f = tid; f < num_faces; f += blockDim.x) {
|
||||
inputs_padded[batch_idx * max_size + f] = inputs[start + f];
|
||||
const int num = end - start;
|
||||
for (size_t f = tid; f < num; f += blockDim.x) {
|
||||
inputs_padded[batch_idx * max_size + f] = inputs_packed[start + f];
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor packed_to_padded_tensor_cuda(
|
||||
at::Tensor inputs,
|
||||
at::Tensor first_idxs,
|
||||
const long max_size) {
|
||||
const auto num_inputs = inputs.size(0);
|
||||
const auto batch_size = first_idxs.size(0);
|
||||
// Kernel for inputs_padded of shape (B, F, D), where D > 1
|
||||
template <typename scalar_t>
|
||||
__global__ void PaddedToPackedKernel(
|
||||
const scalar_t* __restrict__ inputs_padded,
|
||||
const int64_t* __restrict__ first_idxs,
|
||||
scalar_t* __restrict__ inputs_packed,
|
||||
const size_t batch_size,
|
||||
const size_t max_size,
|
||||
const size_t num_inputs,
|
||||
const size_t D) {
|
||||
// Batch elements split evenly across blocks (num blocks = batch_size) and
|
||||
// values for each element split across threads in the block. Each thread adds
|
||||
// the values of its respective input elements to the global inputs_packed
|
||||
// tensor.
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t batch_idx = blockIdx.x;
|
||||
|
||||
const int64_t start = first_idxs[batch_idx];
|
||||
const int64_t end =
|
||||
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
|
||||
const int num = end - start;
|
||||
for (size_t f = tid; f < num; f += blockDim.x) {
|
||||
for (size_t j = 0; j < D; ++j) {
|
||||
inputs_packed[(start + f) * D + j] =
|
||||
inputs_padded[batch_idx * max_size * D + f * D + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Kernel for inputs_padded of shape (B, F, 1)
|
||||
template <typename scalar_t>
|
||||
__global__ void PaddedToPackedKernelD1(
|
||||
const scalar_t* __restrict__ inputs_padded,
|
||||
const int64_t* __restrict__ first_idxs,
|
||||
scalar_t* __restrict__ inputs_packed,
|
||||
const size_t batch_size,
|
||||
const size_t max_size,
|
||||
const size_t num_inputs) {
|
||||
// Batch elements split evenly across blocks (num blocks = batch_size) and
|
||||
// values for each element split across threads in the block. Each thread adds
|
||||
// the values of its respective input elements to the global inputs_packed
|
||||
// tensor.
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t batch_idx = blockIdx.x;
|
||||
|
||||
const int64_t start = first_idxs[batch_idx];
|
||||
const int64_t end =
|
||||
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
|
||||
const int num = end - start;
|
||||
for (size_t f = tid; f < num; f += blockDim.x) {
|
||||
inputs_packed[start + f] = inputs_padded[batch_idx * max_size + f];
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor PackedToPaddedCuda(
|
||||
const at::Tensor inputs_packed,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t max_size) {
|
||||
const int64_t num_inputs = inputs_packed.size(0);
|
||||
const int64_t batch_size = first_idxs.size(0);
|
||||
|
||||
AT_ASSERTM(
|
||||
inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
|
||||
const int64_t D = inputs_packed.size(1);
|
||||
at::Tensor inputs_padded =
|
||||
at::zeros({batch_size, max_size}, inputs.options());
|
||||
at::zeros({batch_size, max_size, D}, inputs_packed.options());
|
||||
|
||||
const int threads = 512;
|
||||
const int blocks = batch_size;
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
inputs.type(), "packed_to_padded_tensor_kernel", ([&] {
|
||||
packed_to_padded_tensor_kernel<scalar_t><<<blocks, threads>>>(
|
||||
inputs.data_ptr<scalar_t>(),
|
||||
first_idxs.data_ptr<long>(),
|
||||
inputs_padded.data_ptr<scalar_t>(),
|
||||
batch_size,
|
||||
max_size,
|
||||
num_inputs);
|
||||
}));
|
||||
if (D == 1) {
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
inputs_packed.type(), "packed_to_padded_d1_kernel", ([&] {
|
||||
PackedToPaddedKernelD1<scalar_t><<<blocks, threads>>>(
|
||||
inputs_packed.data_ptr<scalar_t>(),
|
||||
first_idxs.data_ptr<int64_t>(),
|
||||
inputs_padded.data_ptr<scalar_t>(),
|
||||
batch_size,
|
||||
max_size,
|
||||
num_inputs);
|
||||
}));
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
inputs_packed.type(), "packed_to_padded_kernel", ([&] {
|
||||
PackedToPaddedKernel<scalar_t><<<blocks, threads>>>(
|
||||
inputs_packed.data_ptr<scalar_t>(),
|
||||
first_idxs.data_ptr<int64_t>(),
|
||||
inputs_padded.data_ptr<scalar_t>(),
|
||||
batch_size,
|
||||
max_size,
|
||||
num_inputs,
|
||||
D);
|
||||
}));
|
||||
}
|
||||
|
||||
return inputs_padded;
|
||||
}
|
||||
|
||||
at::Tensor PaddedToPackedCuda(
|
||||
const at::Tensor inputs_padded,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t num_inputs) {
|
||||
const int64_t batch_size = inputs_padded.size(0);
|
||||
const int64_t max_size = inputs_padded.size(1);
|
||||
|
||||
AT_ASSERTM(batch_size == first_idxs.size(0), "sizes mismatch");
|
||||
AT_ASSERTM(
|
||||
inputs_padded.dim() == 3,
|
||||
"inputs_padded must be a 3-dimensional tensor");
|
||||
const int64_t D = inputs_padded.size(2);
|
||||
|
||||
at::Tensor inputs_packed =
|
||||
at::zeros({num_inputs, D}, inputs_padded.options());
|
||||
|
||||
const int threads = 512;
|
||||
const int blocks = batch_size;
|
||||
|
||||
if (D == 1) {
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
inputs_padded.type(), "padded_to_packed_d1_kernel", ([&] {
|
||||
PaddedToPackedKernelD1<scalar_t><<<blocks, threads>>>(
|
||||
inputs_padded.data_ptr<scalar_t>(),
|
||||
first_idxs.data_ptr<int64_t>(),
|
||||
inputs_packed.data_ptr<scalar_t>(),
|
||||
batch_size,
|
||||
max_size,
|
||||
num_inputs);
|
||||
}));
|
||||
} else {
|
||||
AT_DISPATCH_FLOATING_TYPES(
|
||||
inputs_padded.type(), "padded_to_packed_kernel", ([&] {
|
||||
PaddedToPackedKernel<scalar_t><<<blocks, threads>>>(
|
||||
inputs_padded.data_ptr<scalar_t>(),
|
||||
first_idxs.data_ptr<int64_t>(),
|
||||
inputs_packed.data_ptr<scalar_t>(),
|
||||
batch_size,
|
||||
max_size,
|
||||
num_inputs,
|
||||
D);
|
||||
}));
|
||||
}
|
||||
|
||||
return inputs_packed;
|
||||
}
|
||||
|
@ -3,42 +3,96 @@
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
// PackedToPadded
|
||||
// Converts a packed tensor into a padded tensor, restoring the batch dimension.
|
||||
// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors.
|
||||
//
|
||||
// Inputs:
|
||||
// inputs: FloatTensor of shape (F,), representing the packed batch tensor.
|
||||
// e.g. areas for faces in a batch of meshes.
|
||||
// inputs_packed: FloatTensor of shape (F, D), representing the packed batch
|
||||
// tensor, e.g. areas for faces in a batch of meshes.
|
||||
// first_idxs: LongTensor of shape (N,) where N is the number of
|
||||
// elements in the batch and `packed_first_idxs[i] = f`
|
||||
// elements in the batch and `first_idxs[i] = f`
|
||||
// means that the inputs for batch element i begin at
|
||||
// `inputs[f]`.
|
||||
// max_size: Max length of an element in the batch.
|
||||
// max_size: Max length of an element in the batch.
|
||||
// Returns:
|
||||
// inputs_padded: FloatTensor of shape (N, max_size) where max_size is max
|
||||
// inputs_padded: FloatTensor of shape (N, max_size, D) where max_size is max
|
||||
// of `sizes`. The values for batch element i which start at
|
||||
// `inputs[packed_first_idxs[i]]` will be copied to
|
||||
// `inputs_padded[i, :]``, with zeros padding out the extra
|
||||
// `inputs_packed[first_idxs[i]]` will be copied to
|
||||
// `inputs_padded[i, :]`, with zeros padding out the extra
|
||||
// inputs.
|
||||
//
|
||||
|
||||
// PaddedToPacked
|
||||
// Converts a padded tensor into a packed tensor.
|
||||
// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors.
|
||||
//
|
||||
// Inputs:
|
||||
// inputs_padded: FloatTensor of shape (N, max_size, D), representing the
|
||||
// padded tensor, e.g. areas for faces in a batch of meshes.
|
||||
// first_idxs: LongTensor of shape (N,) where N is the number of
|
||||
// elements in the batch and `first_idxs[i] = f`
|
||||
// means that the inputs for batch element i begin at
|
||||
// `inputs_packed[f]`.
|
||||
// num_inputs: Number of packed entries (= F)
|
||||
// Returns:
|
||||
// inputs_packed: FloatTensor of shape (F, D), where
|
||||
// `inputs_packed[first_idx[i]:] = inputs_padded[i, :]`.
|
||||
//
|
||||
//
|
||||
|
||||
// Cpu implementation.
|
||||
at::Tensor PackedToPaddedCpu(
|
||||
const at::Tensor inputs_packed,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t max_size);
|
||||
|
||||
// Cpu implementation.
|
||||
at::Tensor PaddedToPackedCpu(
|
||||
const at::Tensor inputs_padded,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t num_inputs);
|
||||
|
||||
#ifdef WITH_CUDA
|
||||
// Cuda implementation.
|
||||
at::Tensor packed_to_padded_tensor_cuda(
|
||||
at::Tensor inputs,
|
||||
at::Tensor first_idxs,
|
||||
const long max_size);
|
||||
at::Tensor PackedToPaddedCuda(
|
||||
const at::Tensor inputs_packed,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t max_size);
|
||||
|
||||
// Cuda implementation.
|
||||
at::Tensor PaddedToPackedCuda(
|
||||
const at::Tensor inputs_padded,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t num_inputs);
|
||||
#endif
|
||||
|
||||
// Implementation which is exposed.
|
||||
at::Tensor packed_to_padded_tensor(
|
||||
at::Tensor inputs,
|
||||
at::Tensor first_idxs,
|
||||
const long max_size) {
|
||||
if (inputs.type().is_cuda()) {
|
||||
at::Tensor PackedToPadded(
|
||||
const at::Tensor inputs_packed,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t max_size) {
|
||||
if (inputs_packed.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return packed_to_padded_tensor_cuda(inputs, first_idxs, max_size);
|
||||
return PackedToPaddedCuda(inputs_packed, first_idxs, max_size);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
AT_ERROR("Not implemented on the CPU.");
|
||||
return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
|
||||
}
|
||||
|
||||
// Implementation which is exposed.
|
||||
at::Tensor PaddedToPacked(
|
||||
const at::Tensor inputs_padded,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t num_inputs) {
|
||||
if (inputs_padded.type().is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
|
||||
}
|
||||
|
@ -0,0 +1,65 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
at::Tensor PackedToPaddedCpu(
|
||||
const at::Tensor inputs_packed,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t max_size) {
|
||||
const int64_t num_inputs = inputs_packed.size(0);
|
||||
const int64_t batch_size = first_idxs.size(0);
|
||||
|
||||
AT_ASSERTM(
|
||||
inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
|
||||
const int64_t D = inputs_packed.size(1);
|
||||
|
||||
torch::Tensor inputs_padded =
|
||||
torch::zeros({batch_size, max_size, D}, inputs_packed.options());
|
||||
|
||||
auto inputs_packed_a = inputs_packed.accessor<float, 2>();
|
||||
auto first_idxs_a = first_idxs.accessor<int64_t, 1>();
|
||||
auto inputs_padded_a = inputs_padded.accessor<float, 3>();
|
||||
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
const int64_t start = first_idxs_a[b];
|
||||
const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs;
|
||||
const int64_t num = end - start;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
for (int j = 0; j < D; ++j) {
|
||||
inputs_padded_a[b][i][j] = inputs_packed_a[start + i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
return inputs_padded;
|
||||
}
|
||||
|
||||
at::Tensor PaddedToPackedCpu(
|
||||
const at::Tensor inputs_padded,
|
||||
const at::Tensor first_idxs,
|
||||
const int64_t num_inputs) {
|
||||
const int64_t batch_size = inputs_padded.size(0);
|
||||
const int64_t max_size = inputs_padded.size(1);
|
||||
|
||||
AT_ASSERTM(
|
||||
inputs_padded.dim() == 3, "inputs_padded must be a 3-dimensional tensor");
|
||||
const int64_t D = inputs_padded.size(2);
|
||||
|
||||
torch::Tensor inputs_packed =
|
||||
torch::zeros({num_inputs, D}, inputs_padded.options());
|
||||
|
||||
auto inputs_padded_a = inputs_padded.accessor<float, 3>();
|
||||
auto first_idxs_a = first_idxs.accessor<int64_t, 1>();
|
||||
auto inputs_packed_a = inputs_packed.accessor<float, 2>();
|
||||
|
||||
for (int b = 0; b < batch_size; ++b) {
|
||||
const int64_t start = first_idxs_a[b];
|
||||
const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs;
|
||||
const int64_t num = end - start;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
for (int j = 0; j < D; ++j) {
|
||||
inputs_packed_a[start + i][j] = inputs_padded_a[b][i][j];
|
||||
}
|
||||
}
|
||||
}
|
||||
return inputs_packed;
|
||||
}
|
@ -4,6 +4,7 @@
|
||||
from .cubify import cubify
|
||||
from .graph_conv import GraphConv
|
||||
from .nearest_neighbor_points import nn_points_idx
|
||||
from .packed_to_padded import packed_to_padded, padded_to_packed
|
||||
from .sample_points_from_meshes import sample_points_from_meshes
|
||||
from .subdivide_meshes import SubdivideMeshes
|
||||
from .vert_align import vert_align
|
||||
|
170
pytorch3d/ops/packed_to_padded.py
Normal file
170
pytorch3d/ops/packed_to_padded.py
Normal file
@ -0,0 +1,170 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import torch
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from pytorch3d import _C
|
||||
|
||||
|
||||
class _PackedToPadded(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for packed_to_padded C++/CUDA implementations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, first_idxs, max_size):
|
||||
"""
|
||||
Args:
|
||||
ctx: Context object used to calculate gradients.
|
||||
inputs: FloatTensor of shape (F, D), representing the packed batch tensor.
|
||||
e.g. areas for faces in a batch of meshes.
|
||||
first_idxs: LongTensor of shape (N,) where N is the number of
|
||||
elements in the batch and `first_idxs[i] = f`
|
||||
means that the inputs for batch element i begin at `inputs[f]`.
|
||||
max_size: Max length of an element in the batch.
|
||||
|
||||
Returns:
|
||||
inputs_padded: FloatTensor of shape (N, max_size, D) where max_size is max
|
||||
of `sizes`. The values for batch element i which start at
|
||||
`inputs[first_idxs[i]]` will be copied to `inputs_padded[i, :]`,
|
||||
with zeros padding out the extra inputs.
|
||||
"""
|
||||
if not (inputs.dim() == 2):
|
||||
raise ValueError("input can only be 2-dimensional.")
|
||||
if not (first_idxs.dim() == 1):
|
||||
raise ValueError("first_idxs can only be 1-dimensional.")
|
||||
if not (inputs.dtype == torch.float32):
|
||||
raise ValueError("input has to be of type torch.float32.")
|
||||
if not (first_idxs.dtype == torch.int64):
|
||||
raise ValueError("first_idxs has to be of type torch.int64.")
|
||||
if not isinstance(max_size, int):
|
||||
raise ValueError("max_size has to be int.")
|
||||
|
||||
ctx.save_for_backward(first_idxs)
|
||||
ctx.num_inputs = int(inputs.shape[0])
|
||||
inputs, first_idxs = inputs.contiguous(), first_idxs.contiguous()
|
||||
inputs_padded = _C.packed_to_padded(inputs, first_idxs, max_size)
|
||||
return inputs_padded
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = grad_output.contiguous()
|
||||
first_idxs = ctx.saved_tensors[0]
|
||||
num_inputs = ctx.num_inputs
|
||||
grad_input = _C.padded_to_packed(grad_output, first_idxs, num_inputs)
|
||||
return grad_input, None, None
|
||||
|
||||
|
||||
def packed_to_padded(inputs, first_idxs, max_size):
|
||||
"""
|
||||
Torch wrapper that handles allowed input shapes. See description below.
|
||||
|
||||
Args:
|
||||
inputs: FloatTensor of shape (F,) or (F, D), representing the packed batch tensor.
|
||||
e.g. areas for faces in a batch of meshes.
|
||||
first_idxs: LongTensor of shape (N,) where N is the number of
|
||||
elements in the batch and `first_idxs[i] = f`
|
||||
means that the inputs for batch element i begin at `inputs[f]`.
|
||||
max_size: Max length of an element in the batch.
|
||||
|
||||
Returns:
|
||||
inputs_padded: FloatTensor of shape (N, max_size) or (N, max_size, D) where max_size is
|
||||
max of `sizes`. The values for batch element i which start at
|
||||
`inputs[first_idxs[i]]` will be copied to `inputs_padded[i, :]`,
|
||||
with zeros padding out the extra inputs.
|
||||
|
||||
To handle the allowed input shapes, we convert the inputs tensor of shape (F,) to (F, 1).
|
||||
We reshape the output back to (N, max_size) from (N, max_size, 1).
|
||||
"""
|
||||
# if inputs is of shape (F,), reshape into (F, 1)
|
||||
flat = False
|
||||
if inputs.dim() == 1:
|
||||
flat = True
|
||||
inputs = inputs.unsqueeze(1)
|
||||
inputs_padded = _PackedToPadded.apply(inputs, first_idxs, max_size)
|
||||
# if flat is True, reshape output to (N, max_size) from (N, max_size, 1)
|
||||
if flat:
|
||||
inputs_padded = inputs_padded.squeeze(2)
|
||||
return inputs_padded
|
||||
|
||||
|
||||
class _PaddedToPacked(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for padded_to_packed C++/CUDA implementations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, first_idxs, num_inputs):
|
||||
"""
|
||||
Args:
|
||||
ctx: Context object used to calculate gradients.
|
||||
inputs: FloatTensor of shape (N, max_size, D), representing the padded tensor.
|
||||
e.g. areas for faces in a batch of meshes.
|
||||
first_idxs: LongTensor of shape (N,) where N is the number of
|
||||
elements in the batch and `first_idxs[i] = f`
|
||||
means that the inputs for batch element i begin at `inputs_packed[f]`.
|
||||
num_inputs: Number of packed entries (= F)
|
||||
|
||||
Returns:
|
||||
inputs_packed: FloatTensor of shape (F, D) where
|
||||
`inputs_packed[first_idx[i]:] = inputs[i, :]`.
|
||||
"""
|
||||
if not (inputs.dim() == 3):
|
||||
raise ValueError("input can only be 3-dimensional.")
|
||||
if not (first_idxs.dim() == 1):
|
||||
raise ValueError("first_idxs can only be 1-dimensional.")
|
||||
if not (inputs.dtype == torch.float32):
|
||||
raise ValueError("input has to be of type torch.float32.")
|
||||
if not (first_idxs.dtype == torch.int64):
|
||||
raise ValueError("first_idxs has to be of type torch.int64.")
|
||||
if not isinstance(num_inputs, int):
|
||||
raise ValueError("max_size has to be int.")
|
||||
|
||||
ctx.save_for_backward(first_idxs)
|
||||
ctx.max_size = inputs.shape[1]
|
||||
inputs, first_idxs = inputs.contiguous(), first_idxs.contiguous()
|
||||
inputs_packed = _C.padded_to_packed(inputs, first_idxs, num_inputs)
|
||||
return inputs_packed
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = grad_output.contiguous()
|
||||
first_idxs = ctx.saved_tensors[0]
|
||||
max_size = ctx.max_size
|
||||
grad_input = _C.packed_to_padded(grad_output, first_idxs, max_size)
|
||||
return grad_input, None, None
|
||||
|
||||
|
||||
def padded_to_packed(inputs, first_idxs, num_inputs):
|
||||
"""
|
||||
Torch wrapper that handles allowed input shapes. See description below.
|
||||
|
||||
Args:
|
||||
inputs: FloatTensor of shape (N, max_size) or (N, max_size, D), representing the
|
||||
padded tensor. e.g. areas for faces in a batch of meshes.
|
||||
first_idxs: LongTensor of shape (N,) where N is the number of
|
||||
elements in the batch and `first_idxs[i] = f`
|
||||
means that the inputs for batch element i begin at `inputs_packed[f]`.
|
||||
num_inputs: Number of packed entries (= F)
|
||||
|
||||
Returns:
|
||||
inputs_packed: FloatTensor of shape (F,) or (F, D) where
|
||||
`inputs_packed[first_idx[i]:] = inputs[i, :]`.
|
||||
|
||||
To handle the allowed input shapes, we convert the inputs tensor of shape (N, max_size)
|
||||
to (N, max_size, 1). We reshape the output back to (F,) from (F, 1).
|
||||
"""
|
||||
# if inputs is of shape (N, max_size), reshape into (N, max_size, 1))
|
||||
flat = False
|
||||
if inputs.dim() == 2:
|
||||
flat = True
|
||||
inputs = inputs.unsqueeze(2)
|
||||
inputs_packed = _PaddedToPacked.apply(inputs, first_idxs, num_inputs)
|
||||
# if flat is True, reshape output to (F,) from (F, 1)
|
||||
if flat:
|
||||
inputs_packed = inputs_packed.squeeze(1)
|
||||
return inputs_packed
|
@ -12,6 +12,8 @@ import torch
|
||||
|
||||
from pytorch3d import _C
|
||||
|
||||
from .packed_to_padded import packed_to_padded
|
||||
|
||||
|
||||
def sample_points_from_meshes(
|
||||
meshes, num_samples: int = 10000, return_normals: bool = False
|
||||
@ -55,7 +57,7 @@ def sample_points_from_meshes(
|
||||
verts, faces
|
||||
) # Face areas can be zero.
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
areas_padded = _C.packed_to_padded_tensor(
|
||||
areas_padded = packed_to_padded(
|
||||
areas, mesh_to_face[meshes.valid], max_faces
|
||||
) # (N, F)
|
||||
|
||||
|
47
tests/bm_packed_to_padded.py
Normal file
47
tests/bm_packed_to_padded.py
Normal file
@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from itertools import product
|
||||
import torch
|
||||
from fvcore.common.benchmark import benchmark
|
||||
|
||||
from test_packed_to_padded import TestPackedToPadded
|
||||
|
||||
|
||||
def bm_packed_to_padded() -> None:
|
||||
kwargs_list = []
|
||||
backend = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
backend.append("cuda:0")
|
||||
|
||||
num_meshes = [2, 10, 32]
|
||||
num_verts = [100, 1000]
|
||||
num_faces = [300, 3000]
|
||||
num_ds = [0, 1, 16]
|
||||
|
||||
test_cases = product(num_meshes, num_verts, num_faces, num_ds, backend)
|
||||
for case in test_cases:
|
||||
n, v, f, d, b = case
|
||||
kwargs_list.append(
|
||||
{
|
||||
"num_meshes": n,
|
||||
"num_verts": v,
|
||||
"num_faces": f,
|
||||
"num_d": d,
|
||||
"device": b,
|
||||
}
|
||||
)
|
||||
benchmark(
|
||||
TestPackedToPadded.packed_to_padded_with_init,
|
||||
"PACKED_TO_PADDED",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
benchmark(
|
||||
TestPackedToPadded.packed_to_padded_with_init_torch,
|
||||
"PACKED_TO_PADDED_TORCH",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
@ -10,56 +10,30 @@ from test_sample_points_from_meshes import TestSamplePoints
|
||||
|
||||
|
||||
def bm_sample_points() -> None:
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
kwargs_list = []
|
||||
num_meshes = [2, 10, 32]
|
||||
num_verts = [100, 1000]
|
||||
num_faces = [300, 3000]
|
||||
num_samples = [5000, 10000]
|
||||
test_cases = product(num_meshes, num_verts, num_faces, num_samples)
|
||||
for case in test_cases:
|
||||
n, v, f, s = case
|
||||
kwargs_list.append(
|
||||
{
|
||||
"num_meshes": n,
|
||||
"num_verts": v,
|
||||
"num_faces": f,
|
||||
"num_samples": s,
|
||||
"device": device,
|
||||
}
|
||||
)
|
||||
benchmark(
|
||||
TestSamplePoints.sample_points_with_init,
|
||||
"SAMPLE_MESH",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
backend = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
backend.append("cuda:0")
|
||||
kwargs_list = []
|
||||
backend_cuda = ["False"]
|
||||
if torch.cuda.is_available():
|
||||
backend_cuda.append("True")
|
||||
|
||||
num_meshes = [2, 10, 32]
|
||||
num_verts = [100, 1000]
|
||||
num_faces = [300, 3000]
|
||||
|
||||
test_cases = product(num_meshes, num_verts, num_faces, backend_cuda)
|
||||
num_samples = [5000, 10000]
|
||||
test_cases = product(num_meshes, num_verts, num_faces, num_samples, backend)
|
||||
for case in test_cases:
|
||||
n, v, f, c = case
|
||||
n, v, f, s, b = case
|
||||
kwargs_list.append(
|
||||
{"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c}
|
||||
{
|
||||
"num_meshes": n,
|
||||
"num_verts": v,
|
||||
"num_faces": f,
|
||||
"num_samples": s,
|
||||
"device": b,
|
||||
}
|
||||
)
|
||||
benchmark(
|
||||
TestSamplePoints.face_areas_with_init,
|
||||
"FACE_AREAS",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
benchmark(
|
||||
TestSamplePoints.packed_to_padded_with_init,
|
||||
"PACKED_TO_PADDED",
|
||||
TestSamplePoints.sample_points_with_init,
|
||||
"SAMPLE_MESH",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
296
tests/test_packed_to_padded.py
Normal file
296
tests/test_packed_to_padded.py
Normal file
@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.ops import packed_to_padded, padded_to_packed
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestPackedToPadded(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(1)
|
||||
|
||||
@staticmethod
|
||||
def init_meshes(
|
||||
num_meshes: int = 10,
|
||||
num_verts: int = 1000,
|
||||
num_faces: int = 3000,
|
||||
device: str = "cpu",
|
||||
):
|
||||
device = torch.device(device)
|
||||
verts_list = []
|
||||
faces_list = []
|
||||
for _ in range(num_meshes):
|
||||
verts = torch.rand(
|
||||
(num_verts, 3), dtype=torch.float32, device=device
|
||||
)
|
||||
faces = torch.randint(
|
||||
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
|
||||
)
|
||||
verts_list.append(verts)
|
||||
faces_list.append(faces)
|
||||
meshes = Meshes(verts_list, faces_list)
|
||||
|
||||
return meshes
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_python(inputs, first_idxs, max_size, device):
|
||||
"""
|
||||
PyTorch implementation of packed_to_padded function.
|
||||
"""
|
||||
num_meshes = first_idxs.size(0)
|
||||
D = inputs.shape[1] if inputs.dim() == 2 else 0
|
||||
if D == 0:
|
||||
inputs_padded = torch.zeros((num_meshes, max_size), device=device)
|
||||
else:
|
||||
inputs_padded = torch.zeros(
|
||||
(num_meshes, max_size, D), device=device
|
||||
)
|
||||
for m in range(num_meshes):
|
||||
s = first_idxs[m]
|
||||
if m == num_meshes - 1:
|
||||
f = inputs.shape[0]
|
||||
else:
|
||||
f = first_idxs[m + 1]
|
||||
inputs_padded[m, :f] = inputs[s:f]
|
||||
|
||||
return inputs_padded
|
||||
|
||||
@staticmethod
|
||||
def padded_to_packed_python(inputs, first_idxs, num_inputs, device):
|
||||
"""
|
||||
PyTorch implementation of padded_to_packed function.
|
||||
"""
|
||||
num_meshes = inputs.size(0)
|
||||
D = inputs.shape[2] if inputs.dim() == 3 else 0
|
||||
if D == 0:
|
||||
inputs_packed = torch.zeros((num_inputs,), device=device)
|
||||
else:
|
||||
inputs_packed = torch.zeros((num_inputs, D), device=device)
|
||||
for m in range(num_meshes):
|
||||
s = first_idxs[m]
|
||||
if m == num_meshes - 1:
|
||||
f = num_inputs
|
||||
else:
|
||||
f = first_idxs[m + 1]
|
||||
inputs_packed[s:f] = inputs[m, :f]
|
||||
|
||||
return inputs_packed
|
||||
|
||||
def _test_packed_to_padded_helper(self, D, device):
|
||||
"""
|
||||
Check the results from packed_to_padded and PyTorch implementations
|
||||
are the same.
|
||||
"""
|
||||
meshes = self.init_meshes(16, 100, 300, device=device)
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
if D == 0:
|
||||
values = torch.rand(
|
||||
(faces.shape[0],), device=device, requires_grad=True
|
||||
)
|
||||
else:
|
||||
values = torch.rand(
|
||||
(faces.shape[0], D), device=device, requires_grad=True
|
||||
)
|
||||
values_torch = values.detach().clone()
|
||||
values_torch.requires_grad = True
|
||||
values_padded = packed_to_padded(
|
||||
values, mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
values_padded_torch = TestPackedToPadded.packed_to_padded_python(
|
||||
values_torch, mesh_to_faces_packed_first_idx, max_faces, device
|
||||
)
|
||||
# check forward
|
||||
self.assertClose(values_padded, values_padded_torch)
|
||||
|
||||
# check backward
|
||||
if D == 0:
|
||||
grad_inputs = torch.rand((len(meshes), max_faces), device=device)
|
||||
else:
|
||||
grad_inputs = torch.rand((len(meshes), max_faces, D), device=device)
|
||||
values_padded.backward(grad_inputs)
|
||||
grad_outputs = values.grad
|
||||
values_padded_torch.backward(grad_inputs)
|
||||
grad_outputs_torch1 = values_torch.grad
|
||||
grad_outputs_torch2 = TestPackedToPadded.padded_to_packed_python(
|
||||
grad_inputs,
|
||||
mesh_to_faces_packed_first_idx,
|
||||
values.size(0),
|
||||
device=device,
|
||||
)
|
||||
self.assertClose(grad_outputs, grad_outputs_torch1)
|
||||
self.assertClose(grad_outputs, grad_outputs_torch2)
|
||||
|
||||
def test_packed_to_padded_flat_cpu(self):
|
||||
self._test_packed_to_padded_helper(0, "cpu")
|
||||
|
||||
def test_packed_to_padded_D1_cpu(self):
|
||||
self._test_packed_to_padded_helper(1, "cpu")
|
||||
|
||||
def test_packed_to_padded_D16_cpu(self):
|
||||
self._test_packed_to_padded_helper(16, "cpu")
|
||||
|
||||
def test_packed_to_padded_flat_cuda(self):
|
||||
self._test_packed_to_padded_helper(0, "cuda:0")
|
||||
|
||||
def test_packed_to_padded_D1_cuda(self):
|
||||
self._test_packed_to_padded_helper(1, "cuda:0")
|
||||
|
||||
def test_packed_to_padded_D16_cuda(self):
|
||||
self._test_packed_to_padded_helper(16, "cuda:0")
|
||||
|
||||
def _test_padded_to_packed_helper(self, D, device):
|
||||
"""
|
||||
Check the results from packed_to_padded and PyTorch implementations
|
||||
are the same.
|
||||
"""
|
||||
meshes = self.init_meshes(16, 100, 300, device=device)
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
||||
max_faces = num_faces_per_mesh.max().item()
|
||||
if D == 0:
|
||||
values = torch.rand((len(meshes), max_faces), device=device)
|
||||
else:
|
||||
values = torch.rand((len(meshes), max_faces, D), device=device)
|
||||
for i, num in enumerate(num_faces_per_mesh):
|
||||
values[i, num:] = 0
|
||||
values.requires_grad = True
|
||||
values_torch = values.detach().clone()
|
||||
values_torch.requires_grad = True
|
||||
values_packed = padded_to_packed(
|
||||
values,
|
||||
mesh_to_faces_packed_first_idx,
|
||||
num_faces_per_mesh.sum().item(),
|
||||
)
|
||||
values_packed_torch = TestPackedToPadded.padded_to_packed_python(
|
||||
values_torch,
|
||||
mesh_to_faces_packed_first_idx,
|
||||
num_faces_per_mesh.sum().item(),
|
||||
device,
|
||||
)
|
||||
# check forward
|
||||
self.assertClose(values_packed, values_packed_torch)
|
||||
|
||||
# check backward
|
||||
if D == 0:
|
||||
grad_inputs = torch.rand(
|
||||
(num_faces_per_mesh.sum().item()), device=device
|
||||
)
|
||||
else:
|
||||
grad_inputs = torch.rand(
|
||||
(num_faces_per_mesh.sum().item(), D), device=device
|
||||
)
|
||||
values_packed.backward(grad_inputs)
|
||||
grad_outputs = values.grad
|
||||
values_packed_torch.backward(grad_inputs)
|
||||
grad_outputs_torch1 = values_torch.grad
|
||||
grad_outputs_torch2 = TestPackedToPadded.packed_to_padded_python(
|
||||
grad_inputs,
|
||||
mesh_to_faces_packed_first_idx,
|
||||
values.size(1),
|
||||
device=device,
|
||||
)
|
||||
self.assertClose(grad_outputs, grad_outputs_torch1)
|
||||
self.assertClose(grad_outputs, grad_outputs_torch2)
|
||||
|
||||
def test_padded_to_packed_flat_cpu(self):
|
||||
self._test_padded_to_packed_helper(0, "cpu")
|
||||
|
||||
def test_padded_to_packed_D1_cpu(self):
|
||||
self._test_padded_to_packed_helper(1, "cpu")
|
||||
|
||||
def test_padded_to_packed_D16_cpu(self):
|
||||
self._test_padded_to_packed_helper(16, "cpu")
|
||||
|
||||
def test_padded_to_packed_flat_cuda(self):
|
||||
self._test_padded_to_packed_helper(0, "cuda:0")
|
||||
|
||||
def test_padded_to_packed_D1_cuda(self):
|
||||
self._test_padded_to_packed_helper(1, "cuda:0")
|
||||
|
||||
def test_padded_to_packed_D16_cuda(self):
|
||||
self._test_padded_to_packed_helper(16, "cuda:0")
|
||||
|
||||
def test_invalid_inputs_shapes(self, device="cuda:0"):
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "input can only be 2-dimensional."
|
||||
):
|
||||
values = torch.rand((100, 50, 2), device=device)
|
||||
first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
|
||||
packed_to_padded(values, first_idxs, 100)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "input can only be 3-dimensional."
|
||||
):
|
||||
values = torch.rand((100,), device=device)
|
||||
first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
|
||||
padded_to_packed(values, first_idxs, 20)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "input can only be 3-dimensional."
|
||||
):
|
||||
values = torch.rand((100, 50, 2, 2), device=device)
|
||||
first_idxs = torch.tensor([0, 80], dtype=torch.int64, device=device)
|
||||
padded_to_packed(values, first_idxs, 20)
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init(
|
||||
num_meshes: int,
|
||||
num_verts: int,
|
||||
num_faces: int,
|
||||
num_d: int,
|
||||
device: str = "cpu",
|
||||
):
|
||||
meshes = TestPackedToPadded.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
if num_d == 0:
|
||||
values = torch.rand((faces.shape[0],), device=meshes.device)
|
||||
else:
|
||||
values = torch.rand((faces.shape[0], num_d), device=meshes.device)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def out():
|
||||
packed_to_padded(values, mesh_to_faces_packed_first_idx, max_faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init_torch(
|
||||
num_meshes: int,
|
||||
num_verts: int,
|
||||
num_faces: int,
|
||||
num_d: int,
|
||||
device: str = "cpu",
|
||||
):
|
||||
meshes = TestPackedToPadded.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
if num_d == 0:
|
||||
values = torch.rand((faces.shape[0],), device=meshes.device)
|
||||
else:
|
||||
values = torch.rand((faces.shape[0], num_d), device=meshes.device)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def out():
|
||||
TestPackedToPadded.packed_to_padded_python(
|
||||
values, mesh_to_faces_packed_first_idx, max_faces, device
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return out
|
@ -294,48 +294,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_tensor(inputs, first_idxs, max_size):
|
||||
"""
|
||||
PyTorch implementation of cuda packed_to_padded_tensor function.
|
||||
"""
|
||||
num_meshes = first_idxs.size(0)
|
||||
inputs_padded = torch.zeros((num_meshes, max_size))
|
||||
for m in range(num_meshes):
|
||||
s = first_idxs[m]
|
||||
if m == num_meshes - 1:
|
||||
f = inputs.size(0)
|
||||
else:
|
||||
f = first_idxs[m + 1]
|
||||
inputs_padded[m, :f] = inputs[s:f]
|
||||
|
||||
return inputs_padded
|
||||
|
||||
def test_packed_to_padded_tensor(self):
|
||||
"""
|
||||
Check the results from packed_to_padded cuda and PyTorch implementions
|
||||
are the same.
|
||||
"""
|
||||
meshes = self.init_meshes(1, 3, 5, device="cuda:0")
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
areas_padded = _C.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
).cpu()
|
||||
areas_padded_cpu = TestSamplePoints.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
self.assertTrue(torch.allclose(areas_padded, areas_padded_cpu))
|
||||
with self.assertRaises(Exception) as err:
|
||||
_C.packed_to_padded_tensor(
|
||||
areas.cpu(), mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
self.assertTrue("Not implemented on the CPU" in str(err.exception))
|
||||
|
||||
@staticmethod
|
||||
def sample_points_with_init(
|
||||
num_meshes: int,
|
||||
@ -344,7 +302,6 @@ class TestSamplePoints(unittest.TestCase):
|
||||
num_samples: int,
|
||||
device: str = "cpu",
|
||||
):
|
||||
device = torch.device(device)
|
||||
verts_list = []
|
||||
faces_list = []
|
||||
for _ in range(num_meshes):
|
||||
@ -366,32 +323,3 @@ class TestSamplePoints(unittest.TestCase):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return sample_points
|
||||
|
||||
@staticmethod
|
||||
def packed_to_padded_with_init(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: str = True
|
||||
):
|
||||
device = "cuda" if cuda else "cpu"
|
||||
meshes = TestSamplePoints.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_faces_packed_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
|
||||
areas, _ = _C.face_areas_normals(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def packed_to_padded():
|
||||
if cuda:
|
||||
_C.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
else:
|
||||
TestSamplePoints.packed_to_padded_tensor(
|
||||
areas, mesh_to_faces_packed_first_idx, max_faces
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return packed_to_padded
|
||||
|
Loading…
x
Reference in New Issue
Block a user