Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-01 03:12:49 +08:00)
Fixed windows MSVC build compatibility (#9)
Summary: Fixed a few MSVC compiler (Visual Studio 2019, MSVC 19.16.27034) compatibility issues:

1. Replaced `long` with `int64_t`: `aten::data_ptr<long>` is not supported by MSVC.
2. pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp: the `inline` function is not correctly recognized by MSVC.
3. pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh: MSVC does not compile `const auto kEpsilon = 1e-30;` into a constant visible to both host and device code; changed it to a macro.
4. pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh: in `const float area2 = pow(area, 2.0);` the literal `2.0` is treated as a double by MSVC and raised an error; changed it to `2.0f`.
5. pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp: the return type of `std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu()` did not match the declaration in rasterize_points_cpu.h.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/9

Reviewed By: nikhilaravi

Differential Revision: D19986567

Pulled By: yuanluxu

fbshipit-source-id: f4d98525d088c99c513b85193db6f0fc69c7f017
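For context on item 1: MSVC on Windows uses the LLP64 data model, where `long` is 32 bits wide, while Linux and macOS use LP64, where it is 64 bits. PyTorch's `torch::kLong` tensors always hold 64-bit elements, so `int64_t` is the element type that matches on every platform. A minimal sketch of the issue (not from this commit; `first_edge_vertex` is a hypothetical helper, not a PyTorch3D API):

```cpp
#include <cstdint>
#include <torch/extension.h>

// Hypothetical helper: read the first vertex index from an edge tensor of
// shape (E, 2) and dtype torch::kLong.
int64_t first_edge_vertex(const torch::Tensor& edges) {
  // edges.data_ptr<long>() compiles on LP64 platforms, where long == int64_t,
  // but not with MSVC, where long is a distinct 32-bit type and PyTorch
  // provides no data_ptr<long> specialization for it.
  const int64_t* data = edges.data_ptr<int64_t>();
  return data[0];
}
```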
parent a3baa367e3
commit 9e21659fc5
INSTALL.md (40 lines changed)
@@ -7,7 +7,7 @@
 The core library is written in PyTorch. Several components have underlying implementation in CUDA for improved performance. A subset of these components have CPU implementations in C++/Pytorch. It is advised to use PyTorch3d with GPU support in order to use all the features.

-- Linux or macOS
+- Linux or macOS or Windows
 - Python ≥ 3.6
 - PyTorch 1.4
 - torchvision that matches the PyTorch installation. You can install them together at pytorch.org to make sure of this.
@@ -72,3 +72,41 @@ To rebuild after installing from a local clone run, `rm -rf build/ **/*.so` then
 ```
 MACOSX_DEPLOYMENT_TARGET=10.14 CC=clang CXX=clang++ pip install -e .
 ```
+
+**Install from local clone on Windows:**
+
+If you are using pre-compiled pytorch 1.4 and torchvision 0.5, you should make the following changes to the pytorch source code to successfully compile with Visual Studio 2019 (MSVC 19.16.27034) and CUDA 10.1.
+
+Change python/Lib/site-packages/torch/include/csrc/jit/script/module.h
+
+L466, 476, 493, 506, 536
+```
+-static constexpr *
++static const *
+```
+Change python/Lib/site-packages/torch/include/csrc/jit/argument_spec.h
+
+L190
+```
+-static constexpr size_t DEPTH_LIMIT = 128;
++static const size_t DEPTH_LIMIT = 128;
+```
+
+Change python/Lib/site-packages/torch/include/pybind11/cast.h
+
+L1449
+```
+-explicit operator type&() { return *(this->value); }
++explicit operator type& () { return *((type*)(this->value)); }
+```
+
+After patching, you can go to "x64 Native Tools Command Prompt for VS 2019" to compile and install
+```
+cd pytorch3d
+python3 setup.py install
+```
+After installing, verify whether all unit tests have passed
+```
+cd tests
+python3 -m unittest discover -p *.py
+```
gather_scatter.cu

@@ -5,7 +5,7 @@
 // TODO(T47953967) to make this cuda kernel support all datatypes.
 __global__ void gather_scatter_kernel(
     const float* __restrict__ input,
-    const long* __restrict__ edges,
+    const int64_t* __restrict__ edges,
     float* __restrict__ output,
     bool directed,
     bool backward,
@@ -21,8 +21,8 @@ __global__ void gather_scatter_kernel(
   // Edges are split evenly across the blocks.
   for (int e = blockIdx.x; e < E; e += gridDim.x) {
     // Get indices of vertices which form the edge.
-    const long v0 = edges[2 * e + v0_idx];
-    const long v1 = edges[2 * e + v1_idx];
+    const int64_t v0 = edges[2 * e + v0_idx];
+    const int64_t v1 = edges[2 * e + v1_idx];

     // Split vertex features evenly across threads.
     // This implementation will be quite wasteful when D<128 since there will be
@@ -57,7 +57,7 @@ at::Tensor gather_scatter_cuda(

   gather_scatter_kernel<<<blocks, threads>>>(
       input.data_ptr<float>(),
-      edges.data_ptr<long>(),
+      edges.data_ptr<int64_t>(),
       output.data_ptr<float>(),
       directed,
       backward,
nearest_neighbor_points.cu

@@ -6,7 +6,7 @@
 template <typename scalar_t>
 __device__ void WarpReduce(
     volatile scalar_t* min_dists,
-    volatile long* min_idxs,
+    volatile int64_t* min_idxs,
     const size_t tid) {
   // s = 32
   if (min_dists[tid] > min_dists[tid + 32]) {
@@ -57,7 +57,7 @@ template <typename scalar_t>
 __global__ void NearestNeighborKernel(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
-    long* __restrict__ idx,
+    int64_t* __restrict__ idx,
     const size_t N,
     const size_t P1,
     const size_t P2,
@@ -74,7 +74,7 @@ __global__ void NearestNeighborKernel(
   extern __shared__ char shared_buf[];
   scalar_t* x = (scalar_t*)shared_buf; // scalar_t[DD]
   scalar_t* min_dists = &x[D_2]; // scalar_t[NUM_THREADS]
-  long* min_idxs = (long*)&min_dists[blockDim.x]; // long[NUM_THREADS]
+  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]

   const size_t n = blockIdx.y; // index of batch element.
   const size_t i = blockIdx.x; // index of point within batch element.
@@ -147,14 +147,14 @@ template <typename scalar_t>
 __global__ void NearestNeighborKernelD3(
     const scalar_t* __restrict__ points1,
     const scalar_t* __restrict__ points2,
-    long* __restrict__ idx,
+    int64_t* __restrict__ idx,
     const size_t N,
     const size_t P1,
     const size_t P2) {
   // Single shared memory buffer which is split and cast to different types.
   extern __shared__ char shared_buf[];
   scalar_t* min_dists = (scalar_t*)shared_buf; // scalar_t[NUM_THREADS]
-  long* min_idxs = (long*)&min_dists[blockDim.x]; // long[NUM_THREADS]
+  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x]; // int64_t[NUM_THREADS]

   const size_t D = 3;
   const size_t n = blockIdx.y; // index of batch element.
@@ -230,12 +230,12 @@ at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
   // Use the specialized kernel for D=3.
   AT_DISPATCH_FLOATING_TYPES(p1.type(), "nearest_neighbor_v3_cuda", ([&] {
     size_t shared_size = threads * sizeof(size_t) +
-        threads * sizeof(long);
+        threads * sizeof(int64_t);
     NearestNeighborKernelD3<scalar_t>
         <<<blocks, threads, shared_size>>>(
             p1.data_ptr<scalar_t>(),
             p2.data_ptr<scalar_t>(),
-            idx.data_ptr<long>(),
+            idx.data_ptr<int64_t>(),
             N,
             P1,
             P2);
@@ -248,11 +248,11 @@ at::Tensor NearestNeighborIdxCuda(at::Tensor p1, at::Tensor p2) {
   // need to be rounded to the next even size.
   size_t D_2 = D + (D % 2);
   size_t shared_size = (D_2 + threads) * sizeof(size_t);
-  shared_size += threads * sizeof(long);
+  shared_size += threads * sizeof(int64_t);
   NearestNeighborKernel<scalar_t><<<blocks, threads, shared_size>>>(
       p1.data_ptr<scalar_t>(),
       p2.data_ptr<scalar_t>(),
-      idx.data_ptr<long>(),
+      idx.data_ptr<int64_t>(),
       N,
       P1,
       P2,
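Note how the hunks above change the host-side `sizeof` arithmetic together with the device-side casts: the dynamic shared-memory size and the pointer types slicing the buffer must agree on element width, or the index region would be under-allocated with MSVC, where `sizeof(long)` is 4 rather than 8. A minimal sketch (not from the commit) of the pairing, assuming the same one-distance-plus-one-index-per-thread layout as the kernels above:

```cuda
#include <cstdint>

// Device side: slice the dynamic shared buffer into a float region and an
// int64_t region, one element of each per thread.
__global__ void ReduceSketch(int64_t* out) {
  extern __shared__ char shared_buf[];
  float* min_dists = (float*)shared_buf;                 // float[blockDim.x]
  int64_t* min_idxs = (int64_t*)&min_dists[blockDim.x];  // int64_t[blockDim.x]
  min_dists[threadIdx.x] = 0.f; // placeholder for the real reduction
  min_idxs[threadIdx.x] = 0;
  __syncthreads();
  if (threadIdx.x == 0) {
    *out = min_idxs[0];
  }
}

// Host side: size the buffer with the *same* types the kernel casts to.
// size_t shared_size = threads * sizeof(float) + threads * sizeof(int64_t);
// ReduceSketch<<<blocks, threads, shared_size>>>(out);
```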
pytorch3d/csrc/rasterize_meshes/geometry_utils.cuh

@@ -7,7 +7,11 @@
 #include "float_math.cuh"

 // Set epsilon for preventing floating point errors and division by 0.
+#ifdef _MSC_VER
+#define kEpsilon 1e-30f
+#else
 const auto kEpsilon = 1e-30;
+#endif

 // Determines whether a point p is on the right side of a 2D line segment
 // given by the end points v0, v1.
@@ -93,7 +97,7 @@ BarycentricCoordsBackward(
     const float2& v2,
     const float3& grad_bary_upstream) {
   const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
-  const float area2 = pow(area, 2.0);
+  const float area2 = pow(area, 2.0f);
   const float e0 = EdgeFunctionForward(p, v1, v2);
   const float e1 = EdgeFunctionForward(p, v2, v0);
   const float e2 = EdgeFunctionForward(p, v0, v1);
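The macro workaround in the first hunk above sidesteps a CUDA rule: a namespace-scope `const` object is a host variable, and device code can only use it where the compiler constant-folds the value, which this MSVC toolchain did not do for `const auto kEpsilon = 1e-30;`. A macro pastes the literal directly into both host and device code, so there is no variable to reference. A sketch of hypothetical alternatives with a similar effect (not what the commit uses; the names are illustrative):

```cuda
// constexpr makes the value a compile-time constant that nvcc can fold into
// device code as well as host code.
constexpr float kEpsilonCx = 1e-30f;

// __constant__ places the value in device constant memory instead; it is then
// addressable from device code but no longer directly usable on the host.
__constant__ float kEpsilonDev = 1e-30f;
```

The second hunk applies the same single-precision discipline: MSVC treats the literal `2.0` as a `double`, so `pow(area, 2.0)` raised an error here, while `pow(area, 2.0f)` keeps the computation in `float`, as would the simpler `area * area`.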
pytorch3d/csrc/rasterize_points/rasterize_points_cpu.cpp

@@ -7,7 +7,7 @@
 // Given a pixel coordinate 0 <= i < S, convert it to a normalized device
 // coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized
 // pixels, and assume that each pixel falls in the *center* of its range.
-inline float PixToNdc(const int i, const int S) {
+static float PixToNdc(const int i, const int S) {
   // NDC x-offset + (i * pixel_width + half_pixel_width)
   return -1 + (2 * i + 1.0f) / S;
 }
@@ -74,7 +74,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
   return std::make_tuple(point_idxs, zbuf, pix_dists);
 }

-std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
+torch::Tensor RasterizePointsCoarseCpu(
     const torch::Tensor& points,
     const int image_size,
     const float radius,
@@ -140,7 +140,7 @@ std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
       bin_y_max = bin_y_min + bin_width;
     }
   }
-  return std::make_tuple(points_per_bin, bin_points);
+  return bin_points;
 }

 torch::Tensor RasterizePointsBackwardCpu(
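To make the NDC comment in the first hunk concrete: with `S` pixels spanning [-1, 1], each pixel is 2/S wide and `PixToNdc` returns the center of pixel `i`. A small standalone sketch (the formula is copied from the hunk above; the `main` driver is illustrative):

```cpp
#include <cstdio>

// Same formula as PixToNdc in rasterize_points_cpu.cpp:
// left edge (-1) + i pixel widths + half a pixel width.
static float PixToNdc(const int i, const int S) {
  return -1 + (2 * i + 1.0f) / S;
}

int main() {
  // With S = 4, the pixel centers land at -0.75, -0.25, 0.25, 0.75.
  for (int i = 0; i < 4; ++i) {
    printf("PixToNdc(%d, 4) = %.2f\n", i, PixToNdc(i, 4));
  }
  return 0;
}
```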