From 4d043fc9ac79d917ec1a72a23aa66a1589c58c7c Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 25 Feb 2022 07:53:34 -0800 Subject: [PATCH] PyTorch 1.7 compatibility Summary: Small changes discovered based on circleCI failures. Reviewed By: patricklabatut Differential Revision: D34426807 fbshipit-source-id: 819860f34b2f367dd24057ca7490284204180a13 --- pytorch3d/common/compat.py | 16 ++++++++++++++-- pytorch3d/io/mtl_io.py | 3 ++- pytorch3d/ops/cubify.py | 5 ++--- pytorch3d/renderer/implicit/raysampling.py | 3 ++- pytorch3d/structures/volumes.py | 7 ++++--- tests/test_point_mesh_distance.py | 6 +----- tests/test_pointclouds.py | 7 +++---- tests/test_raysampling.py | 5 +++-- tests/test_rendering_utils.py | 5 +++-- 9 files changed, 34 insertions(+), 23 deletions(-) diff --git a/pytorch3d/common/compat.py b/pytorch3d/common/compat.py index e15a6107..1b81c386 100644 --- a/pytorch3d/common/compat.py +++ b/pytorch3d/common/compat.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Tuple +from typing import Sequence, Tuple, Union import torch @@ -57,4 +57,16 @@ def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no co """ if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"): return torch.linalg.eigh(A) - return torch.symeig(A, eigenvalues=True) + return torch.symeig(A, eigenvectors=True) + + +def meshgrid_ij( + *A: Union[torch.Tensor, Sequence[torch.Tensor]] +) -> Tuple[torch.Tensor, ...]: # pragma: no cover + """ + Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij + """ + if "indexing" in torch.meshgrid.__kwdefaults__: + # PyTorch >= 1.10.0 + return torch.meshgrid(*A, indexing="ij") + return torch.meshgrid(*A) diff --git a/pytorch3d/io/mtl_io.py b/pytorch3d/io/mtl_io.py index 7429ec16..40de8259 100644 --- a/pytorch3d/io/mtl_io.py +++ b/pytorch3d/io/mtl_io.py @@ -13,6 +13,7 @@ import numpy as np import torch import torch.nn.functional as F from iopath.common.file_io import PathManager +from pytorch3d.common.compat import meshgrid_ij from pytorch3d.common.datatypes import Device from pytorch3d.io.utils import _open_file, _read_image @@ -273,7 +274,7 @@ def make_material_atlas( # Meshgrid returns (row, column) i.e (Y, X) # Change order to (X, Y) to make the grid. - Y, X = torch.meshgrid(rng, rng) + Y, X = meshgrid_ij(rng, rng) # pyre-fixme[28]: Unexpected keyword argument `axis`. grid = torch.stack([X, Y], axis=-1) # (R, R, 2) diff --git a/pytorch3d/ops/cubify.py b/pytorch3d/ops/cubify.py index b5105134..7af449b2 100644 --- a/pytorch3d/ops/cubify.py +++ b/pytorch3d/ops/cubify.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F +from pytorch3d.common.compat import meshgrid_ij from pytorch3d.structures import Meshes @@ -195,9 +196,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes: # NF x 3 grid_faces = torch.stack(grid_faces, dim=1) - y, x, z = torch.meshgrid( - torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1) - ) + y, x, z = meshgrid_ij(torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)) y = y.to(device=device, dtype=torch.float32) x = x.to(device=device, dtype=torch.float32) z = z.to(device=device, dtype=torch.float32) diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index e868a18d..b73a94cd 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -8,6 +8,7 @@ import warnings from typing import Optional import torch +from pytorch3d.common.compat import meshgrid_ij from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.implicit.utils import RayBundle from torch.nn import functional as F @@ -103,7 +104,7 @@ class MultinomialRaysampler(torch.nn.Module): _xy_grid = torch.stack( tuple( reversed( - torch.meshgrid( + meshgrid_ij( torch.linspace(min_y, max_y, image_height, dtype=torch.float32), torch.linspace(min_x, max_x, image_width, dtype=torch.float32), ) diff --git a/pytorch3d/structures/volumes.py b/pytorch3d/structures/volumes.py index d96964f9..d9cd5ad4 100644 --- a/pytorch3d/structures/volumes.py +++ b/pytorch3d/structures/volumes.py @@ -8,9 +8,10 @@ import copy from typing import List, Optional, Tuple, Union import torch +from pytorch3d.common.compat import meshgrid_ij +from pytorch3d.common.datatypes import Device, make_device +from pytorch3d.transforms import Scale, Transform3d -from ..common.datatypes import Device, make_device -from ..transforms import Scale, Transform3d from . import utils as struct_utils @@ -393,7 +394,7 @@ class Volumes: ] # generate per-coord meshgrids - Z, Y, X = torch.meshgrid(vol_axes) + Z, Y, X = meshgrid_ij(vol_axes) # stack the coord grids ... this order matches the coordinate convention # of torch.nn.grid_sample diff --git a/tests/test_point_mesh_distance.py b/tests/test_point_mesh_distance.py index 2976db15..af6dd854 100644 --- a/tests/test_point_mesh_distance.py +++ b/tests/test_point_mesh_distance.py @@ -11,11 +11,7 @@ import torch from common_testing import TestCaseMixin, get_random_cuda_device from pytorch3d import _C from pytorch3d.loss import point_mesh_edge_distance, point_mesh_face_distance -from pytorch3d.structures import ( - Meshes, - Pointclouds, - packed_to_list, -) +from pytorch3d.structures import Meshes, Pointclouds, packed_to_list class TestPointMeshDistance(TestCaseMixin, unittest.TestCase): diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index fa37368f..52efee18 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -1033,7 +1033,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): for i, cloud in enumerate(clouds.points_list()): within_box_naive.append(inside_box_naive(cloud, box[i, 0], box[i, 1])) within_box_naive = torch.cat(within_box_naive, 0) - self.assertClose(within_box, within_box_naive) + self.assertTrue(torch.equal(within_box, within_box_naive)) # box of shape 2x3 box2 = box[0, :] @@ -1044,13 +1044,12 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): for cloud in clouds.points_list(): within_box_naive2.append(inside_box_naive(cloud, box2[0], box2[1])) within_box_naive2 = torch.cat(within_box_naive2, 0) - self.assertClose(within_box2, within_box_naive2) - + self.assertTrue(torch.equal(within_box2, within_box_naive2)) # box of shape 1x2x3 box3 = box2.expand(1, 2, 3) within_box3 = clouds.inside_box(box3) - self.assertClose(within_box2, within_box3) + self.assertTrue(torch.equal(within_box2, within_box3)) # invalid box invalid_box = torch.cat( diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index cb99ed31..da43439e 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -9,6 +9,7 @@ from typing import Callable import torch from common_testing import TestCaseMixin +from pytorch3d.common.compat import meshgrid_ij from pytorch3d.ops import eyes from pytorch3d.renderer import ( MonteCarloRaysampler, @@ -86,7 +87,7 @@ class TestNDCRaysamplerConvention(TestCaseMixin, unittest.TestCase): min_y = range_y - half_pix_height max_y = -range_y + half_pix_height - y_grid, x_grid = torch.meshgrid( + y_grid, x_grid = meshgrid_ij( torch.linspace(min_y, max_y, h, dtype=torch.float32), torch.linspace(min_x, max_x, w, dtype=torch.float32), ) @@ -540,7 +541,7 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): self.assertTupleEqual(out.shape, data.shape) # Check `out` is in ascending order - self.assertGreater(torch.diff(out, dim=-1).min(), 0) + self.assertGreater((out[..., 1:] - out[..., :-1]).min(), 0) self.assertConstant(out[..., :-1] < data[..., 1:], True) self.assertConstant(data[..., :-1] < out[..., 1:], True) diff --git a/tests/test_rendering_utils.py b/tests/test_rendering_utils.py index 7589b244..3a7c668d 100644 --- a/tests/test_rendering_utils.py +++ b/tests/test_rendering_utils.py @@ -10,6 +10,7 @@ import unittest import numpy as np import torch from common_testing import TestCaseMixin +from pytorch3d.common.compat import meshgrid_ij from pytorch3d.ops import eyes from pytorch3d.renderer import ( AlphaCompositor, @@ -129,8 +130,8 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase): point_radius = 0.015 n_pts = n_grid_pts * n_grid_pts pts = torch.stack( - torch.meshgrid( - [torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2, indexing="ij" + meshgrid_ij( + [torch.linspace(-grid_scale, grid_scale, n_grid_pts)] * 2, ), dim=-1, )