suppress errors in vision/fair/pytorch3d

Differential Revision: D27934268

fbshipit-source-id: 51185fa493451012a9b2fd37379897d60596f73b
This commit is contained in:
Pyre Bot Jr 2021-04-21 23:27:13 -07:00 committed by Facebook GitHub Bot
parent eb04a488c5
commit 04d318d88f
8 changed files with 44 additions and 51 deletions

View File

@ -443,7 +443,7 @@ class BlenderCamera(CamerasBase):
def get_projection_transform(self, **kwargs) -> Transform3d:
transform = Transform3d(device=self.device)
transform._matrix = self.K.transpose(1, 2).contiguous() # pyre-ignore[16]
transform._matrix = self.K.transpose(1, 2).contiguous()
return transform

View File

@ -97,7 +97,6 @@ def corresponding_cameras_alignment(
cameras_src_aligned: `cameras_src` after applying the alignment transform.
"""
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
if cameras_src.R.shape[0] != cameras_tgt.R.shape[0]:
raise ValueError(
"cameras_src and cameras_tgt have to contain the same number of cameras!"
@ -121,7 +120,6 @@ def corresponding_cameras_alignment(
torch.bmm(
align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1), cameras_src.R
)[:, 0]
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
+ cameras_src.T * align_t_s
)
@ -169,7 +167,6 @@ def _align_camera_extrinsics(
R_A = (U V^T)^T
```
"""
# pyre-fixme[16]: `CamerasBase` has no attribute `R`.
RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
U, _, V = torch.svd(RRcov)
align_t_R = V @ U.t()
@ -199,8 +196,15 @@ def _align_camera_extrinsics(
T_A = mean(B) - mean(A) * s_A
```
"""
# pyre-fixme[16]: `CamerasBase` has no attribute `T`.
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
Amu = A.mean(0, keepdim=True)
Bmu = B.mean(0, keepdim=True)

View File

@ -478,17 +478,17 @@ class FoVPerspectiveCameras(CamerasBase):
[0, 0, 1, 0],
]
"""
K = kwargs.get("K", self.K) # pyre-ignore[16]
K = kwargs.get("K", self.K)
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
K = self.compute_projection_matrix(
kwargs.get("znear", self.znear), # pyre-ignore[16]
kwargs.get("zfar", self.zfar), # pyre-ignore[16]
kwargs.get("fov", self.fov), # pyre-ignore[16]
kwargs.get("aspect_ratio", self.aspect_ratio), # pyre-ignore[16]
kwargs.get("znear", self.znear),
kwargs.get("zfar", self.zfar),
kwargs.get("fov", self.fov),
kwargs.get("aspect_ratio", self.aspect_ratio),
kwargs.get("degrees", self.degrees),
)
@ -702,20 +702,20 @@ class FoVOrthographicCameras(CamerasBase):
[0, 0, 0, 1],
]
"""
K = kwargs.get("K", self.K) # pyre-ignore[16]
K = kwargs.get("K", self.K)
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
K = self.compute_projection_matrix(
kwargs.get("znear", self.znear), # pyre-ignore[16]
kwargs.get("zfar", self.zfar), # pyre-ignore[16]
kwargs.get("max_x", self.max_x), # pyre-ignore[16]
kwargs.get("min_x", self.min_x), # pyre-ignore[16]
kwargs.get("max_y", self.max_y), # pyre-ignore[16]
kwargs.get("min_y", self.min_y), # pyre-ignore[16]
kwargs.get("scale_xyz", self.scale_xyz), # pyre-ignore[16]
kwargs.get("znear", self.znear),
kwargs.get("zfar", self.zfar),
kwargs.get("max_x", self.max_x),
kwargs.get("min_x", self.min_x),
kwargs.get("max_y", self.max_y),
kwargs.get("min_y", self.min_y),
kwargs.get("scale_xyz", self.scale_xyz),
)
transform = Transform3d(device=self.device)
@ -902,13 +902,12 @@ class PerspectiveCameras(CamerasBase):
[0, 0, 1, 0],
]
"""
K = kwargs.get("K", self.K) # pyre-ignore[16]
K = kwargs.get("K", self.K)
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
image_size = image_size if image_size[0][0] > 0 else None
@ -916,8 +915,8 @@ class PerspectiveCameras(CamerasBase):
K = _get_sfm_calibration_matrix(
self._N,
self.device,
kwargs.get("focal_length", self.focal_length), # pyre-ignore[16]
kwargs.get("principal_point", self.principal_point), # pyre-ignore[16]
kwargs.get("focal_length", self.focal_length),
kwargs.get("principal_point", self.principal_point),
orthographic=False,
image_size=image_size,
)
@ -1067,13 +1066,12 @@ class OrthographicCameras(CamerasBase):
[0, 0, 0, 1],
]
"""
K = kwargs.get("K", self.K) # pyre-ignore[16]
K = kwargs.get("K", self.K)
if K is not None:
if K.shape != (self._N, 4, 4):
msg = "Expected K to have shape of (%r, 4, 4)"
raise ValueError(msg % (self._N))
else:
# pyre-ignore[16]
image_size = kwargs.get("image_size", self.image_size)
# if imwidth > 0, parameters are in screen space
image_size = image_size if image_size[0][0] > 0 else None
@ -1081,8 +1079,8 @@ class OrthographicCameras(CamerasBase):
K = _get_sfm_calibration_matrix(
self._N,
self.device,
kwargs.get("focal_length", self.focal_length), # pyre-ignore[16]
kwargs.get("principal_point", self.principal_point), # pyre-ignore[16]
kwargs.get("focal_length", self.focal_length),
kwargs.get("principal_point", self.principal_point),
orthographic=True,
image_size=image_size,
)

View File

@ -113,7 +113,7 @@ class GridRaysampler(torch.nn.Module):
containing the 2D image coordinates of each ray.
"""
batch_size = cameras.R.shape[0] # pyre-ignore
batch_size = cameras.R.shape[0]
device = cameras.device
@ -229,7 +229,7 @@ class MonteCarloRaysampler(torch.nn.Module):
containing the 2D image coordinates of each ray.
"""
batch_size = cameras.R.shape[0] # pyre-ignore
batch_size = cameras.R.shape[0]
device = cameras.device

View File

@ -183,7 +183,6 @@ class DirectionalLights(TensorProperties):
direction=direction,
)
_validate_light_properties(self)
# pyre-fixme[16]: `DirectionalLights` has no attribute `direction`.
if self.direction.shape[-1] != 3:
msg = "Expected direction to have shape (N, 3); got %r"
raise ValueError(msg % repr(self.direction.shape))
@ -198,9 +197,7 @@ class DirectionalLights(TensorProperties):
# need to know the light type.
return diffuse(
normals=normals,
# pyre-fixme[16]: `DirectionalLights` has no attribute `diffuse_color`.
color=self.diffuse_color,
# pyre-fixme[16]: `DirectionalLights` has no attribute `direction`.
direction=self.direction,
)
@ -208,9 +205,7 @@ class DirectionalLights(TensorProperties):
return specular(
points=points,
normals=normals,
# pyre-fixme[16]: `DirectionalLights` has no attribute `specular_color`.
color=self.specular_color,
# pyre-fixme[16]: `DirectionalLights` has no attribute `direction`.
direction=self.direction,
camera_position=camera_position,
shininess=shininess,
@ -249,7 +244,6 @@ class PointLights(TensorProperties):
location=location,
)
_validate_light_properties(self)
# pyre-fixme[16]: `PointLights` has no attribute `location`.
if self.location.shape[-1] != 3:
msg = "Expected location to have shape (N, 3); got %r"
raise ValueError(msg % repr(self.location.shape))
@ -259,18 +253,14 @@ class PointLights(TensorProperties):
return super().clone(other)
def diffuse(self, normals, points) -> torch.Tensor:
# pyre-fixme[16]: `PointLights` has no attribute `location`.
direction = self.location - points
# pyre-fixme[16]: `PointLights` has no attribute `diffuse_color`.
return diffuse(normals=normals, color=self.diffuse_color, direction=direction)
def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
# pyre-fixme[16]: `PointLights` has no attribute `location`.
direction = self.location - points
return specular(
points=points,
normals=normals,
# pyre-fixme[16]: `PointLights` has no attribute `specular_color`.
color=self.specular_color,
direction=direction,
camera_position=camera_position,

View File

@ -632,7 +632,6 @@ class Renderer(torch.nn.Module):
enabled.
"""
# The device tracker is registered as buffer.
# pyre-fixme[16]: `Renderer` has no attribute `device_tracker`.
self._renderer.device_tracker = self.device_tracker
(
pos_vec,

View File

@ -189,17 +189,11 @@ class PulsarPointsRenderer(nn.Module):
if orthogonal_projection:
focal_length = torch.zeros((1,), dtype=torch.float32)
if isinstance(cameras, FoVOrthographicCameras):
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `znear`.
znear = kwargs.get("znear", cameras.znear)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `zfar`.
zfar = kwargs.get("zfar", cameras.zfar)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_y`.
max_y = kwargs.get("max_y", cameras.max_y)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_y`.
min_y = kwargs.get("min_y", cameras.min_y)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `max_x`.
max_x = kwargs.get("max_x", cameras.max_x)[cloud_idx]
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `min_x`.
min_x = kwargs.get("min_x", cameras.min_x)[cloud_idx]
if max_y != -min_y:
raise ValueError(
@ -212,9 +206,7 @@ class PulsarPointsRenderer(nn.Module):
f"Max is {max_x} and min is {min_x}."
)
if not torch.all(
# pyre-fixme[16]: `FoVOrthographicCameras` has no attribute `scale_xyz`.
kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx]
== 1.0
kwargs.get("scale_xyz", cameras.scale_xyz)[cloud_idx] == 1.0
):
raise ValueError(
"The orthographic camera scale must be ((1.0, 1.0, 1.0),). "
@ -297,7 +289,6 @@ class PulsarPointsRenderer(nn.Module):
torch.zeros((1,), dtype=torch.float32),
)
else:
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `focal_length`.
focal_length_conf = kwargs.get("focal_length", cameras.focal_length)[
cloud_idx
]
@ -336,7 +327,6 @@ class PulsarPointsRenderer(nn.Module):
)
sensor_width = focal_length / focal_length_px * 2.0
principal_point_x = (
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
kwargs.get("principal_point", cameras.principal_point)[cloud_idx][0]
* 0.5
* self.renderer._renderer.width

View File

@ -470,10 +470,22 @@ def _add_struct_from_batch(
struct = None
if isinstance(batched_struct, CamerasBase):
# we can't index directly into camera batches
R, T = batched_struct.R, batched_struct.T # pyre-ignore[16]
R, T = batched_struct.R, batched_struct.T
# pyre-fixme[6]: Expected `Sized` for 1st param but got `Union[torch.Tensor,
# torch.nn.Module]`.
r_idx = min(scene_num, len(R) - 1)
# pyre-fixme[6]: Expected `Sized` for 1st param but got `Union[torch.Tensor,
# torch.nn.Module]`.
t_idx = min(scene_num, len(T) - 1)
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
R = R[r_idx].unsqueeze(0)
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
T = T[t_idx].unsqueeze(0)
struct = CamerasBase(device=batched_struct.device, R=R, T=T)
else: # batched meshes and pointclouds are indexable