PyTorch 1.7 compatibility

Summary: Small changes discovered based on circleCI failures.

Reviewed By: patricklabatut

Differential Revision: D34426807

fbshipit-source-id: 819860f34b2f367dd24057ca7490284204180a13
This commit is contained in:
Jeremy Reizenstein
2022-02-25 07:53:34 -08:00
committed by Facebook GitHub Bot
parent f816568735
commit 4d043fc9ac
9 changed files with 34 additions and 23 deletions

View File

@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
from typing import Sequence, Tuple, Union
import torch
@@ -57,4 +57,16 @@ def eigh(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no co
"""
if hasattr(torch, "linalg") and hasattr(torch.linalg, "eigh"):
return torch.linalg.eigh(A)
return torch.symeig(A, eigenvalues=True)
return torch.symeig(A, eigenvectors=True)
def meshgrid_ij(
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
"""
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
"""
if "indexing" in torch.meshgrid.__kwdefaults__:
# PyTorch >= 1.10.0
return torch.meshgrid(*A, indexing="ij")
return torch.meshgrid(*A)

View File

@@ -13,6 +13,7 @@ import numpy as np
import torch
import torch.nn.functional as F
from iopath.common.file_io import PathManager
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.common.datatypes import Device
from pytorch3d.io.utils import _open_file, _read_image
@@ -273,7 +274,7 @@ def make_material_atlas(
# Meshgrid returns (row, column) i.e (Y, X)
# Change order to (X, Y) to make the grid.
Y, X = torch.meshgrid(rng, rng)
Y, X = meshgrid_ij(rng, rng)
# pyre-fixme[28]: Unexpected keyword argument `axis`.
grid = torch.stack([X, Y], axis=-1) # (R, R, 2)

View File

@@ -7,6 +7,7 @@
import torch
import torch.nn.functional as F
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.structures import Meshes
@@ -195,9 +196,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
# NF x 3
grid_faces = torch.stack(grid_faces, dim=1)
y, x, z = torch.meshgrid(
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
)
y, x, z = meshgrid_ij(torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1))
y = y.to(device=device, dtype=torch.float32)
x = x.to(device=device, dtype=torch.float32)
z = z.to(device=device, dtype=torch.float32)

View File

@@ -8,6 +8,7 @@ import warnings
from typing import Optional
import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit.utils import RayBundle
from torch.nn import functional as F
@@ -103,7 +104,7 @@ class MultinomialRaysampler(torch.nn.Module):
_xy_grid = torch.stack(
tuple(
reversed(
torch.meshgrid(
meshgrid_ij(
torch.linspace(min_y, max_y, image_height, dtype=torch.float32),
torch.linspace(min_x, max_x, image_width, dtype=torch.float32),
)

View File

@@ -8,9 +8,10 @@ import copy
from typing import List, Optional, Tuple, Union
import torch
from pytorch3d.common.compat import meshgrid_ij
from pytorch3d.common.datatypes import Device, make_device
from pytorch3d.transforms import Scale, Transform3d
from ..common.datatypes import Device, make_device
from ..transforms import Scale, Transform3d
from . import utils as struct_utils
@@ -393,7 +394,7 @@ class Volumes:
]
# generate per-coord meshgrids
Z, Y, X = torch.meshgrid(vol_axes)
Z, Y, X = meshgrid_ij(vol_axes)
# stack the coord grids ... this order matches the coordinate convention
# of torch.nn.grid_sample