mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
PyTorch 1.7 compatibility
Summary: Small changes discovered based on circleCI failures. Reviewed By: patricklabatut Differential Revision: D34426807 fbshipit-source-id: 819860f34b2f367dd24057ca7490284204180a13
This commit is contained in:
parent
f816568735
commit
4d043fc9ac
@ -4,7 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
from typing import Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -57,4 +57,16 @@ def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no co
|
||||
"""
|
||||
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
|
||||
return torch.linalg.eigh(A)
|
||||
return torch.symeig(A, eigenvalues=True)
|
||||
return torch.symeig(A, eigenvectors=True)
|
||||
|
||||
|
||||
def meshgrid_ij(
|
||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
|
||||
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
||||
"""
|
||||
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
||||
"""
|
||||
if "indexing" in torch.meshgrid.__kwdefaults__:
|
||||
# PyTorch >= 1.10.0
|
||||
return torch.meshgrid(*A, indexing="ij")
|
||||
return torch.meshgrid(*A)
|
||||
|
@ -13,6 +13,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
from pytorch3d.common.datatypes import Device
|
||||
from pytorch3d.io.utils import _open_file, _read_image
|
||||
|
||||
@ -273,7 +274,7 @@ def make_material_atlas(
|
||||
|
||||
# Meshgrid returns (row, column) i.e (Y, X)
|
||||
# Change order to (X, Y) to make the grid.
|
||||
Y, X = torch.meshgrid(rng, rng)
|
||||
Y, X = meshgrid_ij(rng, rng)
|
||||
# pyre-fixme[28]: Unexpected keyword argument `axis`.
|
||||
grid = torch.stack([X, Y], axis=-1) # (R, R, 2)
|
||||
|
||||
|
@ -7,6 +7,7 @@
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
|
||||
@ -195,9 +196,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
|
||||
# NF x 3
|
||||
grid_faces = torch.stack(grid_faces, dim=1)
|
||||
|
||||
y, x, z = torch.meshgrid(
|
||||
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
|
||||
)
|
||||
y, x, z = meshgrid_ij(torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1))
|
||||
y = y.to(device=device, dtype=torch.float32)
|
||||
x = x.to(device=device, dtype=torch.float32)
|
||||
z = z.to(device=device, dtype=torch.float32)
|
||||
|
@ -8,6 +8,7 @@ import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.renderer.implicit.utils import RayBundle
|
||||
from torch.nn import functional as F
|
||||
@ -103,7 +104,7 @@ class MultinomialRaysampler(torch.nn.Module):
|
||||
_xy_grid = torch.stack(
|
||||
tuple(
|
||||
reversed(
|
||||
torch.meshgrid(
|
||||
meshgrid_ij(
|
||||
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
|
||||
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
|
||||
)
|
||||
|
@ -8,9 +8,10 @@ import copy
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
from pytorch3d.common.datatypes import Device, make_device
|
||||
from pytorch3d.transforms import Scale, Transform3d
|
||||
|
||||
from ..common.datatypes import Device, make_device
|
||||
from ..transforms import Scale, Transform3d
|
||||
from . import utils as struct_utils
|
||||
|
||||
|
||||
@ -393,7 +394,7 @@ class Volumes:
|
||||
]
|
||||
|
||||
# generate per-coord meshgrids
|
||||
Z, Y, X = torch.meshgrid(vol_axes)
|
||||
Z, Y, X = meshgrid_ij(vol_axes)
|
||||
|
||||
# stack the coord grids ... this order matches the coordinate convention
|
||||
# of torch.nn.grid_sample
|
||||
|
@ -11,11 +11,7 @@ import torch
|
||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance
|
||||
from pytorch3d.structures import (
|
||||
Meshes,
|
||||
Pointclouds,
|
||||
packed_to_list,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, Pointclouds, packed_to_list
|
||||
|
||||
|
||||
class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
|
||||
|
@ -1033,7 +1033,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
for i, cloud in enumerate(clouds.points_list()):
|
||||
within_box_naive.append(inside_box_naive(cloud, box[i, 0], box[i, 1]))
|
||||
within_box_naive = torch.cat(within_box_naive, 0)
|
||||
self.assertClose(within_box, within_box_naive)
|
||||
self.assertTrue(torch.equal(within_box, within_box_naive))
|
||||
|
||||
# box of shape 2x3
|
||||
box2 = box[0, :]
|
||||
@ -1044,13 +1044,12 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
for cloud in clouds.points_list():
|
||||
within_box_naive2.append(inside_box_naive(cloud, box2[0], box2[1]))
|
||||
within_box_naive2 = torch.cat(within_box_naive2, 0)
|
||||
self.assertClose(within_box2, within_box_naive2)
|
||||
|
||||
self.assertTrue(torch.equal(within_box2, within_box_naive2))
|
||||
# box of shape 1x2x3
|
||||
box3 = box2.expand(1, 2, 3)
|
||||
|
||||
within_box3 = clouds.inside_box(box3)
|
||||
self.assertClose(within_box2, within_box3)
|
||||
self.assertTrue(torch.equal(within_box2, within_box3))
|
||||
|
||||
# invalid box
|
||||
invalid_box = torch.cat(
|
||||
|
@ -9,6 +9,7 @@ from typing import Callable
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
from pytorch3d.ops import eyes
|
||||
from pytorch3d.renderer import (
|
||||
MonteCarloRaysampler,
|
||||
@ -86,7 +87,7 @@ class TestNDCRaysamplerConvention(TestCaseMixin, unittest.TestCase):
|
||||
min_y = range_y - half_pix_height
|
||||
max_y = -range_y + half_pix_height
|
||||
|
||||
y_grid, x_grid = torch.meshgrid(
|
||||
y_grid, x_grid = meshgrid_ij(
|
||||
torch.linspace(min_y, max_y, h, dtype=torch.float32),
|
||||
torch.linspace(min_x, max_x, w, dtype=torch.float32),
|
||||
)
|
||||
@ -540,7 +541,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTupleEqual(out.shape, data.shape)
|
||||
|
||||
# Check `out` is in ascending order
|
||||
self.assertGreater(torch.diff(out, dim=-1).min(), 0)
|
||||
self.assertGreater((out[..., 1:] - out[..., :-1]).min(), 0)
|
||||
|
||||
self.assertConstant(out[..., :-1] < data[..., 1:], True)
|
||||
self.assertConstant(data[..., :-1] < out[..., 1:], True)
|
||||
|
@ -10,6 +10,7 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
from pytorch3d.ops import eyes
|
||||
from pytorch3d.renderer import (
|
||||
AlphaCompositor,
|
||||
@ -129,8 +130,8 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
|
||||
point_radius = 0.015
|
||||
n_pts = n_grid_pts * n_grid_pts
|
||||
pts = torch.stack(
|
||||
torch.meshgrid(
|
||||
[torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2, indexing="ij"
|
||||
meshgrid_ij(
|
||||
[torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user