C++ support for packed to padded

Summary:
CPU implementations for packed to padded and padded to packed, with gradients added.
```
Benchmark                                     Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
PACKED_TO_PADDED_2_100_300_1_cpu                    138             221           3625
PACKED_TO_PADDED_2_100_300_1_cuda:0                 184             261           2716
PACKED_TO_PADDED_2_100_300_16_cpu                   555             726            901
PACKED_TO_PADDED_2_100_300_16_cuda:0                179             260           2794
PACKED_TO_PADDED_2_100_3000_1_cpu                   396             519           1262
PACKED_TO_PADDED_2_100_3000_1_cuda:0                181             274           2764
PACKED_TO_PADDED_2_100_3000_16_cpu                 4517            5003            111
PACKED_TO_PADDED_2_100_3000_16_cuda:0               224             397           2235
PACKED_TO_PADDED_2_1000_300_1_cpu                   138             212           3616
PACKED_TO_PADDED_2_1000_300_1_cuda:0                180             282           2775
PACKED_TO_PADDED_2_1000_300_16_cpu                  565             711            885
PACKED_TO_PADDED_2_1000_300_16_cuda:0               179             264           2797
PACKED_TO_PADDED_2_1000_3000_1_cpu                  389             494           1287
PACKED_TO_PADDED_2_1000_3000_1_cuda:0               180             271           2777
PACKED_TO_PADDED_2_1000_3000_16_cpu                4522            5170            111
PACKED_TO_PADDED_2_1000_3000_16_cuda:0              216             286           2313
PACKED_TO_PADDED_10_100_300_1_cpu                   251             345           1995
PACKED_TO_PADDED_10_100_300_1_cuda:0                178             262           2806
PACKED_TO_PADDED_10_100_300_16_cpu                 2354            2750            213
PACKED_TO_PADDED_10_100_300_16_cuda:0               178             291           2814
PACKED_TO_PADDED_10_100_3000_1_cpu                 1519            1786            330
PACKED_TO_PADDED_10_100_3000_1_cuda:0               179             237           2791
PACKED_TO_PADDED_10_100_3000_16_cpu               24705           25879             21
PACKED_TO_PADDED_10_100_3000_16_cuda:0              228             316           2191
PACKED_TO_PADDED_10_1000_300_1_cpu                  261             432           1919
PACKED_TO_PADDED_10_1000_300_1_cuda:0               181             261           2756
PACKED_TO_PADDED_10_1000_300_16_cpu                2349            2770            213
PACKED_TO_PADDED_10_1000_300_16_cuda:0              180             256           2782
PACKED_TO_PADDED_10_1000_3000_1_cpu                1613            1929            310
PACKED_TO_PADDED_10_1000_3000_1_cuda:0              183             253           2739
PACKED_TO_PADDED_10_1000_3000_16_cpu              22041           23653             23
PACKED_TO_PADDED_10_1000_3000_16_cuda:0             220             343           2270
PACKED_TO_PADDED_32_100_300_1_cpu                   555             750            901
PACKED_TO_PADDED_32_100_300_1_cuda:0                188             282           2661
PACKED_TO_PADDED_32_100_300_16_cpu                 7550            8131             67
PACKED_TO_PADDED_32_100_300_16_cuda:0               181             272           2770
PACKED_TO_PADDED_32_100_3000_1_cpu                 4574            6327            110
PACKED_TO_PADDED_32_100_3000_1_cuda:0               173             254           2884
PACKED_TO_PADDED_32_100_3000_16_cpu               70366           72563              8
PACKED_TO_PADDED_32_100_3000_16_cuda:0              349             654           1433
PACKED_TO_PADDED_32_1000_300_1_cpu                  612             728            818
PACKED_TO_PADDED_32_1000_300_1_cuda:0               189             295           2647
PACKED_TO_PADDED_32_1000_300_16_cpu                7699            8254             65
PACKED_TO_PADDED_32_1000_300_16_cuda:0              189             311           2646
PACKED_TO_PADDED_32_1000_3000_1_cpu                5105            5261             98
PACKED_TO_PADDED_32_1000_3000_1_cuda:0              191             260           2625
PACKED_TO_PADDED_32_1000_3000_16_cpu              87073           92708              6
PACKED_TO_PADDED_32_1000_3000_16_cuda:0             344             425           1455
--------------------------------------------------------------------------------

Benchmark                                           Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
PACKED_TO_PADDED_TORCH_2_100_300_1_cpu                    492             627           1016
PACKED_TO_PADDED_TORCH_2_100_300_1_cuda:0                 768             975            652
PACKED_TO_PADDED_TORCH_2_100_300_16_cpu                   659             804            760
PACKED_TO_PADDED_TORCH_2_100_300_16_cuda:0                781             918            641
PACKED_TO_PADDED_TORCH_2_100_3000_1_cpu                   624             734            802
PACKED_TO_PADDED_TORCH_2_100_3000_1_cuda:0                778             929            643
PACKED_TO_PADDED_TORCH_2_100_3000_16_cpu                 2609            2850            192
PACKED_TO_PADDED_TORCH_2_100_3000_16_cuda:0               758             901            660
PACKED_TO_PADDED_TORCH_2_1000_300_1_cpu                   467             612           1072
PACKED_TO_PADDED_TORCH_2_1000_300_1_cuda:0                772             905            648
PACKED_TO_PADDED_TORCH_2_1000_300_16_cpu                  689             839            726
PACKED_TO_PADDED_TORCH_2_1000_300_16_cuda:0               789            1143            635
PACKED_TO_PADDED_TORCH_2_1000_3000_1_cpu                  629             735            795
PACKED_TO_PADDED_TORCH_2_1000_3000_1_cuda:0               812             916            616
PACKED_TO_PADDED_TORCH_2_1000_3000_16_cpu                2716            3117            185
PACKED_TO_PADDED_TORCH_2_1000_3000_16_cuda:0              844            1288            593
PACKED_TO_PADDED_TORCH_10_100_300_1_cpu                  2387            2557            210
PACKED_TO_PADDED_TORCH_10_100_300_1_cuda:0               4112            4993            122
PACKED_TO_PADDED_TORCH_10_100_300_16_cpu                 3385            4254            148
PACKED_TO_PADDED_TORCH_10_100_300_16_cuda:0              3959            4902            127
PACKED_TO_PADDED_TORCH_10_100_3000_1_cpu                 2918            3105            172
PACKED_TO_PADDED_TORCH_10_100_3000_1_cuda:0              4054            4450            124
PACKED_TO_PADDED_TORCH_10_100_3000_16_cpu               12748           13623             40
PACKED_TO_PADDED_TORCH_10_100_3000_16_cuda:0             4023            4395            125
PACKED_TO_PADDED_TORCH_10_1000_300_1_cpu                 2258            2492            222
PACKED_TO_PADDED_TORCH_10_1000_300_1_cuda:0              3997            4312            126
PACKED_TO_PADDED_TORCH_10_1000_300_16_cpu                3404            3597            147
PACKED_TO_PADDED_TORCH_10_1000_300_16_cuda:0             3877            4227            129
PACKED_TO_PADDED_TORCH_10_1000_3000_1_cpu                2789            3054            180
PACKED_TO_PADDED_TORCH_10_1000_3000_1_cuda:0             3821            4402            131
PACKED_TO_PADDED_TORCH_10_1000_3000_16_cpu              11967           12963             42
PACKED_TO_PADDED_TORCH_10_1000_3000_16_cuda:0            3729            4290            135
PACKED_TO_PADDED_TORCH_32_100_300_1_cpu                  6933            8152             73
PACKED_TO_PADDED_TORCH_32_100_300_1_cuda:0              11856           12287             43
PACKED_TO_PADDED_TORCH_32_100_300_16_cpu                 9895           11205             51
PACKED_TO_PADDED_TORCH_32_100_300_16_cuda:0             12354           13596             41
PACKED_TO_PADDED_TORCH_32_100_3000_1_cpu                 9516           10128             53
PACKED_TO_PADDED_TORCH_32_100_3000_1_cuda:0             12917           13597             39
PACKED_TO_PADDED_TORCH_32_100_3000_16_cpu               41209           43783             13
PACKED_TO_PADDED_TORCH_32_100_3000_16_cuda:0            12210           13288             41
PACKED_TO_PADDED_TORCH_32_1000_300_1_cpu                 7179            7689             70
PACKED_TO_PADDED_TORCH_32_1000_300_1_cuda:0             11896           12381             43
PACKED_TO_PADDED_TORCH_32_1000_300_16_cpu               10127           15494             50
PACKED_TO_PADDED_TORCH_32_1000_300_16_cuda:0            12034           12817             42
PACKED_TO_PADDED_TORCH_32_1000_3000_1_cpu                8743           10251             58
PACKED_TO_PADDED_TORCH_32_1000_3000_1_cuda:0            12023           12908             42
PACKED_TO_PADDED_TORCH_32_1000_3000_16_cpu              39071           41777             13
PACKED_TO_PADDED_TORCH_32_1000_3000_16_cuda:0           11999           13690             42
--------------------------------------------------------------------------------
```
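
For reference, the op these benchmarks exercise copies a packed tensor of shape (F, D), plus per-element start indices, into a zero-padded tensor of shape (N, max_size, D). The sketch below is illustrative only (made-up sizes, plain std::vector instead of ATen tensors); it mirrors the indexing used by the CPU implementation further down.
```cpp
// Illustrative only: packed -> padded on plain vectors, hypothetical sizes
// (N = 3 batch elements with 2, 1 and 3 rows, D = 2 features, max_size = 3).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t D = 2, max_size = 3;
  // Packed rows for all batch elements, stacked along dim 0: (F = 6, D = 2).
  std::vector<std::vector<float>> packed = {
      {0.f, 1.f}, {2.f, 3.f},               // element 0 (2 rows)
      {4.f, 5.f},                           // element 1 (1 row)
      {6.f, 7.f}, {8.f, 9.f}, {10.f, 11.f}  // element 2 (3 rows)
  };
  const std::vector<int64_t> first_idxs = {0, 2, 3};  // start row of each element
  const int64_t N = first_idxs.size();
  const int64_t F = packed.size();

  // padded: (N, max_size, D), zero padded past each element's length.
  std::vector<float> padded(N * max_size * D, 0.f);
  for (int64_t b = 0; b < N; ++b) {
    const int64_t start = first_idxs[b];
    const int64_t end = (b + 1 < N) ? first_idxs[b + 1] : F;
    for (int64_t i = 0; i < end - start; ++i)
      for (int64_t j = 0; j < D; ++j)
        padded[(b * max_size + i) * D + j] = packed[start + i][j];
  }
  std::printf("padded[1][0][1] = %.1f\n", padded[(1 * max_size + 0) * D + 1]);  // 5.0
  std::printf("padded[1][1][0] = %.1f\n", padded[(1 * max_size + 1) * D + 0]);  // 0.0 (padding)
}
```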

Reviewed By: bottler, nikhilaravi, jcjohnson

Differential Revision: D19870575

fbshipit-source-id: 23a2477b73373c411899633386c87ab034c3702a
Author: Georgia Gkioxari
Date: 2020-02-19 10:46:51 -08:00
Committed by: Facebook Github Bot
Parent: 8301163d24
Commit: 60f3c4e7d2
13 changed files with 850 additions and 158 deletions

View File

@@ -10,7 +10,8 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("face_areas_normals", &FaceAreasNormals);
m.def("packed_to_padded_tensor", &packed_to_padded_tensor);
m.def("packed_to_padded", &PackedToPadded);
m.def("padded_to_packed", &PaddedToPacked);
m.def("nn_points_idx", &NearestNeighborIdx);
m.def("gather_scatter", &gather_scatter);
m.def("rasterize_points", &RasterizePoints);

View File

@@ -21,10 +21,12 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCpu(
at::Tensor verts,
at::Tensor faces);
#ifdef WITH_CUDA
// Cuda implementation.
std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsCuda(
at::Tensor verts,
at::Tensor faces);
#endif
// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> FaceAreasNormals(

View File

@@ -22,8 +22,10 @@
// CPU implementation.
at::Tensor NearestNeighborIdxCpu(at::Tensor p1, at::Tensor p2);
#ifdef WITH_CUDA
// Cuda implementation.
at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2);
#endif
// Implementation which is exposed.
at::Tensor NearestNeighborIdx(at::Tensor p1, at::Tensor p2) {

View File

@@ -1,11 +1,42 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <torch/extension.h>
// Kernel for inputs_packed of shape (F, D), where D > 1
template <typename scalar_t>
__global__ void packed_to_padded_tensor_kernel(
const scalar_t* __restrict__ inputs,
const long* __restrict__ first_idxs,
__global__ void PackedToPaddedKernel(
const scalar_t* __restrict__ inputs_packed,
const int64_t* __restrict__ first_idxs,
scalar_t* __restrict__ inputs_padded,
const size_t batch_size,
const size_t max_size,
const size_t num_inputs,
const size_t D) {
// Batch elements split evenly across blocks (num blocks = batch_size) and
// values for each element split across threads in the block. Each thread adds
// the values of its respective input elements to the global inputs_padded
// tensor.
const size_t tid = threadIdx.x;
const size_t batch_idx = blockIdx.x;
const int64_t start = first_idxs[batch_idx];
const int64_t end =
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
const int num = end - start;
for (size_t f = tid; f < num; f += blockDim.x) {
for (size_t j = 0; j < D; ++j) {
inputs_padded[batch_idx * max_size * D + f * D + j] =
inputs_packed[(start + f) * D + j];
}
}
}
// Kernel for inputs of shape (F, 1)
template <typename scalar_t>
__global__ void PackedToPaddedKernelD1(
const scalar_t* __restrict__ inputs_packed,
const int64_t* __restrict__ first_idxs,
scalar_t* __restrict__ inputs_padded,
const size_t batch_size,
const size_t max_size,
@@ -17,36 +48,155 @@ __global__ void packed_to_padded_tensor_kernel(
const size_t tid = threadIdx.x;
const size_t batch_idx = blockIdx.x;
const long start = first_idxs[batch_idx];
const long end =
const int64_t start = first_idxs[batch_idx];
const int64_t end =
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
const int num_faces = end - start;
for (size_t f = tid; f < num_faces; f += blockDim.x) {
inputs_padded[batch_idx * max_size + f] = inputs[start + f];
const int num = end - start;
for (size_t f = tid; f < num; f += blockDim.x) {
inputs_padded[batch_idx * max_size + f] = inputs_packed[start + f];
}
}
at::Tensor packed_to_padded_tensor_cuda(
at::Tensor inputs,
at::Tensor first_idxs,
const long max_size) {
const auto num_inputs = inputs.size(0);
const auto batch_size = first_idxs.size(0);
// Kernel for inputs_padded of shape (B, F, D), where D > 1
template <typename scalar_t>
__global__ void PaddedToPackedKernel(
const scalar_t* __restrict__ inputs_padded,
const int64_t* __restrict__ first_idxs,
scalar_t* __restrict__ inputs_packed,
const size_t batch_size,
const size_t max_size,
const size_t num_inputs,
const size_t D) {
// Batch elements split evenly across blocks (num blocks = batch_size) and
// values for each element split across threads in the block. Each thread adds
// the values of its respective input elements to the global inputs_packed
// tensor.
const size_t tid = threadIdx.x;
const size_t batch_idx = blockIdx.x;
const int64_t start = first_idxs[batch_idx];
const int64_t end =
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
const int num = end - start;
for (size_t f = tid; f < num; f += blockDim.x) {
for (size_t j = 0; j < D; ++j) {
inputs_packed[(start + f) * D + j] =
inputs_padded[batch_idx * max_size * D + f * D + j];
}
}
}
// Kernel for inputs_padded of shape (B, F, 1)
template <typename scalar_t>
__global__ void PaddedToPackedKernelD1(
const scalar_t* __restrict__ inputs_padded,
const int64_t* __restrict__ first_idxs,
scalar_t* __restrict__ inputs_packed,
const size_t batch_size,
const size_t max_size,
const size_t num_inputs) {
// Batch elements split evenly across blocks (num blocks = batch_size) and
// values for each element split across threads in the block. Each thread adds
// the values of its respective input elements to the global inputs_packed
// tensor.
const size_t tid = threadIdx.x;
const size_t batch_idx = blockIdx.x;
const int64_t start = first_idxs[batch_idx];
const int64_t end =
batch_idx + 1 < batch_size ? first_idxs[batch_idx + 1] : num_inputs;
const int num = end - start;
for (size_t f = tid; f < num; f += blockDim.x) {
inputs_packed[start + f] = inputs_padded[batch_idx * max_size + f];
}
}
at::Tensor PackedToPaddedCuda(
const at::Tensor inputs_packed,
const at::Tensor first_idxs,
const int64_t max_size) {
const int64_t num_inputs = inputs_packed.size(0);
const int64_t batch_size = first_idxs.size(0);
AT_ASSERTM(
inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
const int64_t D = inputs_packed.size(1);
at::Tensor inputs_padded =
at::zeros({batch_size, max_size}, inputs.options());
at::zeros({batch_size, max_size, D}, inputs_packed.options());
const int threads = 512;
const int blocks = batch_size;
AT_DISPATCH_FLOATING_TYPES(
inputs.type(), "packed_to_padded_tensor_kernel", ([&] {
packed_to_padded_tensor_kernel<scalar_t><<<blocks, threads>>>(
inputs.data_ptr<scalar_t>(),
first_idxs.data_ptr<long>(),
inputs_padded.data_ptr<scalar_t>(),
batch_size,
max_size,
num_inputs);
}));
if (D == 1) {
AT_DISPATCH_FLOATING_TYPES(
inputs_packed.type(), "packed_to_padded_d1_kernel", ([&] {
PackedToPaddedKernelD1<scalar_t><<<blocks, threads>>>(
inputs_packed.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_padded.data_ptr<scalar_t>(),
batch_size,
max_size,
num_inputs);
}));
} else {
AT_DISPATCH_FLOATING_TYPES(
inputs_packed.type(), "packed_to_padded_kernel", ([&] {
PackedToPaddedKernel<scalar_t><<<blocks, threads>>>(
inputs_packed.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_padded.data_ptr<scalar_t>(),
batch_size,
max_size,
num_inputs,
D);
}));
}
return inputs_padded;
}
at::Tensor PaddedToPackedCuda(
const at::Tensor inputs_padded,
const at::Tensor first_idxs,
const int64_t num_inputs) {
const int64_t batch_size = inputs_padded.size(0);
const int64_t max_size = inputs_padded.size(1);
AT_ASSERTM(batch_size == first_idxs.size(0), "sizes mismatch");
AT_ASSERTM(
inputs_padded.dim() == 3,
"inputs_padded must be a 3-dimensional tensor");
const int64_t D = inputs_padded.size(2);
at::Tensor inputs_packed =
at::zeros({num_inputs, D}, inputs_padded.options());
const int threads = 512;
const int blocks = batch_size;
if (D == 1) {
AT_DISPATCH_FLOATING_TYPES(
inputs_padded.type(), "padded_to_packed_d1_kernel", ([&] {
PaddedToPackedKernelD1<scalar_t><<<blocks, threads>>>(
inputs_padded.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_packed.data_ptr<scalar_t>(),
batch_size,
max_size,
num_inputs);
}));
} else {
AT_DISPATCH_FLOATING_TYPES(
inputs_padded.type(), "padded_to_packed_kernel", ([&] {
PaddedToPackedKernel<scalar_t><<<blocks, threads>>>(
inputs_padded.data_ptr<scalar_t>(),
first_idxs.data_ptr<int64_t>(),
inputs_packed.data_ptr<scalar_t>(),
batch_size,
max_size,
num_inputs,
D);
}));
}
return inputs_packed;
}
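
Both CUDA kernels above use one block per batch element and 512 threads per block, with each thread striding over that element's rows. The snippet below is a host-only sketch (hypothetical sizes, no CUDA needed) that emulates the per-thread loop for a single (blockIdx.x, threadIdx.x) pair to show which packed rows it would copy.
```cpp
// Illustrative only: emulates
//   for (f = tid; f < num; f += blockDim.x)
// from PackedToPaddedKernel for one (blockIdx.x, threadIdx.x) pair.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> first_idxs = {0, 5, 8};  // batch_size = 3
  const int64_t num_inputs = 12;                      // F
  const int64_t blockDim_x = 4;                       // 512 in the real kernels

  const int64_t batch_idx = 1;  // blockIdx.x
  const int64_t tid = 2;        // threadIdx.x
  const int64_t start = first_idxs[batch_idx];
  const int64_t end = (batch_idx + 1 < (int64_t)first_idxs.size())
      ? first_idxs[batch_idx + 1]
      : num_inputs;
  for (int64_t f = tid; f < end - start; f += blockDim_x) {
    // Thread (1, 2) copies packed row start + f into padded row (batch_idx, f).
    std::printf("packed row %lld -> padded[%lld][%lld]\n",
                (long long)(start + f), (long long)batch_idx, (long long)f);
  }
}
```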

View File

@@ -3,42 +3,96 @@
#pragma once
#include <torch/extension.h>
// PackedToPadded
// Converts a packed tensor into a padded tensor, restoring the batch dimension.
// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors.
//
// Inputs:
// inputs: FloatTensor of shape (F,), representing the packed batch tensor.
// e.g. areas for faces in a batch of meshes.
// inputs_packed: FloatTensor of shape (F, D), representing the packed batch
// tensor, e.g. areas for faces in a batch of meshes.
// first_idxs: LongTensor of shape (N,) where N is the number of
// elements in the batch and `packed_first_idxs[i] = f`
// elements in the batch and `first_idxs[i] = f`
// means that the inputs for batch element i begin at
// `inputs[f]`.
// max_size: Max length of an element in the batch.
// max_size: Max length of an element in the batch.
// Returns:
// inputs_padded: FloatTensor of shape (N, max_size) where max_size is max
// inputs_padded: FloatTensor of shape (N, max_size, D) where max_size is max
// of `sizes`. The values for batch element i which start at
// `inputs[packed_first_idxs[i]]` will be copied to
// `inputs_padded[i, :]``, with zeros padding out the extra
// `inputs_packed[first_idxs[i]]` will be copied to
// `inputs_padded[i, :]`, with zeros padding out the extra
// inputs.
//
// PaddedToPacked
// Converts a padded tensor into a packed tensor.
// Refer to pytorch3d/structures/meshes.py for details on packed/padded tensors.
//
// Inputs:
// inputs_padded: FloatTensor of shape (N, max_size, D), representing the
// padded tensor, e.g. areas for faces in a batch of meshes.
// first_idxs: LongTensor of shape (N,) where N is the number of
// elements in the batch and `first_idxs[i] = f`
// means that the inputs for batch element i begin at
// `inputs_packed[f]`.
// num_inputs: Number of packed entries (= F)
// Returns:
// inputs_packed: FloatTensor of shape (F, D), where
// `inputs_packed[first_idxs[i]:] = inputs_padded[i, :]`.
//
//
// Cpu implementation.
at::Tensor PackedToPaddedCpu(
const at::Tensor inputs_packed,
const at::Tensor first_idxs,
const int64_t max_size);
// Cpu implementation.
at::Tensor PaddedToPackedCpu(
const at::Tensor inputs_padded,
const at::Tensor first_idxs,
const int64_t num_inputs);
#ifdef WITH_CUDA
// Cuda implementation.
at::Tensor packed_to_padded_tensor_cuda(
at::Tensor inputs,
at::Tensor first_idxs,
const long max_size);
at::Tensor PackedToPaddedCuda(
const at::Tensor inputs_packed,
const at::Tensor first_idxs,
const int64_t max_size);
// Cuda implementation.
at::Tensor PaddedToPackedCuda(
const at::Tensor inputs_padded,
const at::Tensor first_idxs,
const int64_t num_inputs);
#endif
// Implementation which is exposed.
at::Tensor packed_to_padded_tensor(
at::Tensor inputs,
at::Tensor first_idxs,
const long max_size) {
if (inputs.type().is_cuda()) {
at::Tensor PackedToPadded(
const at::Tensor inputs_packed,
const at::Tensor first_idxs,
const int64_t max_size) {
if (inputs_packed.type().is_cuda()) {
#ifdef WITH_CUDA
return packed_to_padded_tensor_cuda(inputs, first_idxs, max_size);
return PackedToPaddedCuda(inputs_packed, first_idxs, max_size);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("Not implemented on the CPU.");
return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
}
// Implementation which is exposed.
at::Tensor PaddedToPacked(
const at::Tensor inputs_padded,
const at::Tensor first_idxs,
const int64_t num_inputs) {
if (inputs_padded.type().is_cuda()) {
#ifdef WITH_CUDA
return PaddedToPackedCuda(inputs_padded, first_idxs, num_inputs);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
}
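
With the dispatch functions above, a round trip through the two ops recovers the packed input. The following is a hedged usage sketch from C++, not part of the diff: it assumes the declarations above are in scope and uses the CPU path with small made-up tensors.
```cpp
// Sketch only: calling PackedToPadded / PaddedToPacked as declared above.
#include <torch/torch.h>

void example() {
  // 6 packed rows with D = 2 features, split 2 / 1 / 3 across the batch.
  at::Tensor inputs_packed = torch::rand({6, 2});
  at::Tensor first_idxs = torch::tensor({0, 2, 3}, torch::kInt64);
  const int64_t max_size = 3;

  // padded: (3, 3, 2), zero padded past each element's length.
  at::Tensor padded = PackedToPadded(inputs_packed, first_idxs, max_size);

  // The reverse op restores the packed layout given the total row count.
  at::Tensor packed_again = PaddedToPacked(padded, first_idxs, /*num_inputs=*/6);
  TORCH_CHECK(torch::allclose(packed_again, inputs_packed));
}
```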

View File

@@ -0,0 +1,65 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
at::Tensor PackedToPaddedCpu(
const at::Tensor inputs_packed,
const at::Tensor first_idxs,
const int64_t max_size) {
const int64_t num_inputs = inputs_packed.size(0);
const int64_t batch_size = first_idxs.size(0);
AT_ASSERTM(
inputs_packed.dim() == 2, "inputs_packed must be a 2-dimensional tensor");
const int64_t D = inputs_packed.size(1);
torch::Tensor inputs_padded =
torch::zeros({batch_size, max_size, D}, inputs_packed.options());
auto inputs_packed_a = inputs_packed.accessor<float, 2>();
auto first_idxs_a = first_idxs.accessor<int64_t, 1>();
auto inputs_padded_a = inputs_padded.accessor<float, 3>();
for (int b = 0; b < batch_size; ++b) {
const int64_t start = first_idxs_a[b];
const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs;
const int64_t num = end - start;
for (int i = 0; i < num; ++i) {
for (int j = 0; j < D; ++j) {
inputs_padded_a[b][i][j] = inputs_packed_a[start + i][j];
}
}
}
return inputs_padded;
}
at::Tensor PaddedToPackedCpu(
const at::Tensor inputs_padded,
const at::Tensor first_idxs,
const int64_t num_inputs) {
const int64_t batch_size = inputs_padded.size(0);
const int64_t max_size = inputs_padded.size(1);
AT_ASSERTM(
inputs_padded.dim() == 3, "inputs_padded must be a 3-dimensional tensor");
const int64_t D = inputs_padded.size(2);
torch::Tensor inputs_packed =
torch::zeros({num_inputs, D}, inputs_padded.options());
auto inputs_padded_a = inputs_padded.accessor<float, 3>();
auto first_idxs_a = first_idxs.accessor<int64_t, 1>();
auto inputs_packed_a = inputs_packed.accessor<float, 2>();
for (int b = 0; b < batch_size; ++b) {
const int64_t start = first_idxs_a[b];
const int64_t end = b + 1 < batch_size ? first_idxs_a[b + 1] : num_inputs;
const int64_t num = end - start;
for (int i = 0; i < num; ++i) {
for (int j = 0; j < D; ++j) {
inputs_packed_a[start + i][j] = inputs_padded_a[b][i][j];
}
}
}
return inputs_packed;
}
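
Both the CPU and CUDA paths read `first_idxs[b]` as the packed row where batch element b starts and infer its end from the next entry (or from num_inputs for the last element), i.e. first_idxs behaves as an exclusive prefix sum of the per-element lengths. A small sketch of deriving it, with hypothetical lengths and a hypothetical helper name:
```cpp
// Illustrative only: first_idxs[i] is the packed row where batch element i
// starts, i.e. an exclusive prefix sum of the per-element lengths.
#include <cstdint>
#include <vector>

std::vector<int64_t> first_idxs_from_lengths(const std::vector<int64_t>& lengths) {
  std::vector<int64_t> first_idxs(lengths.size());
  int64_t running = 0;
  for (size_t i = 0; i < lengths.size(); ++i) {
    first_idxs[i] = running;
    running += lengths[i];
  }
  return first_idxs;  // e.g. lengths {2, 1, 3} -> {0, 2, 3}; num_inputs = 6
}
```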