PyTorch 1.4 compat

Summary: Restore compatibility with PyTorch 1.4 and 1.5, and a few lint fixes.

Reviewed By: patricklabatut

Differential Revision: D30048115

fbshipit-source-id: ee05efa7c625f6079fb06a3cc23be93e48df9433
This commit is contained in:
Jeremy Reizenstein 2021-08-03 08:09:39 -07:00 committed by Facebook GitHub Bot
parent 55aaec4d83
commit 5ecce83217
4 changed files with 21 additions and 20 deletions

View File

@ -19,7 +19,7 @@ def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
Like torch.linalg.solve, tries to return X Like torch.linalg.solve, tries to return X
such that AX=B, with A square. such that AX=B, with A square.
""" """
if hasattr(torch.linalg, "solve"): if hasattr(torch, "linalg") and hasattr(torch.linalg, "solve"):
# PyTorch version >= 1.8.0 # PyTorch version >= 1.8.0
return torch.linalg.solve(A, B) return torch.linalg.solve(A, B)
@ -31,7 +31,7 @@ def lstsq(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
Like torch.linalg.lstsq, tries to return X Like torch.linalg.lstsq, tries to return X
such that AX=B. such that AX=B.
""" """
if hasattr(torch.linalg, "lstsq"): if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):
# PyTorch version >= 1.9 # PyTorch version >= 1.9
return torch.linalg.lstsq(A, B).solution return torch.linalg.lstsq(A, B).solution
@ -45,7 +45,7 @@ def qr(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cove
""" """
Like torch.linalg.qr. Like torch.linalg.qr.
""" """
if hasattr(torch.linalg, "qr"): if hasattr(torch, "linalg") and hasattr(torch.linalg, "qr"):
# PyTorch version >= 1.9 # PyTorch version >= 1.9
return torch.linalg.qr(A) return torch.linalg.qr(A)
return torch.qr(A) return torch.qr(A)

View File

@ -6,7 +6,7 @@
import math import math
import warnings import warnings
from typing import Optional, Sequence, Tuple, Union, List from typing import List, Optional, Sequence, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -259,8 +259,9 @@ class CamerasBase(TensorProperties):
# users might might have to implement the screen to NDC transform based # users might might have to implement the screen to NDC transform based
# on the definition of the camera parameters. # on the definition of the camera parameters.
# See PerspectiveCameras/OrthographicCameras for an example. # See PerspectiveCameras/OrthographicCameras for an example.
# We don't flip xy because we assume that world points are in PyTorch3D coodrinates # We don't flip xy because we assume that world points are in
# and thus conversion from screen to ndc is a mere scaling from image to [-1, 1] scale. # PyTorch3D coordinates, and thus conversion from screen to ndc
# is a mere scaling from image to [-1, 1] scale.
return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs) return get_screen_to_ndc_transform(self, with_xyflip=False, **kwargs)
def transform_points_ndc( def transform_points_ndc(

View File

@ -551,7 +551,6 @@ class PulsarPointsRenderer(nn.Module):
otherargs["bg_col"] = bg_col otherargs["bg_col"] = bg_col
# Go! # Go!
images.append( images.append(
torch.flipud(
self.renderer( self.renderer(
vert_pos=vert_pos, vert_pos=vert_pos,
vert_col=vert_col, vert_col=vert_col,
@ -561,7 +560,6 @@ class PulsarPointsRenderer(nn.Module):
max_depth=zfar, max_depth=zfar,
min_depth=znear, min_depth=znear,
**otherargs, **otherargs,
) ).flip(dims=[0])
)
) )
return torch.stack(images, dim=0) return torch.stack(images, dim=0)

View File

@ -140,8 +140,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
dim=-2, dim=-2,
) )
# clipping is not important here; if q_abs is small, the candidate won't be picked # We floor here at 0.1 but the exact level is not important; if q_abs is small,
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clip(0.1)) # the candidate won't be picked.
# pyre-ignore [16]: `torch.Tensor` has no attribute `new_tensor`.
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(q_abs.new_tensor(0.1)))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign), # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator) # forall i; we pick the best-conditioned one (with the largest denominator)