Annotate dunder functions

Summary: Annotate the (return type of the) following dunder functions across the codebase: `__init__()`, `__len__()`, `__getitem__()`

Reviewed By: nikhilaravi

Differential Revision: D29001801

fbshipit-source-id: 928d9e1c417ffe01ab8c0445311287786e997c7c
This commit is contained in:
Patrick Labatut 2021-06-24 15:18:19 -07:00 committed by Facebook GitHub Bot
parent 35855bf860
commit 64289a491d
35 changed files with 79 additions and 79 deletions

View File

@ -36,7 +36,7 @@ class ListDataset(Dataset):
A simple dataset made of a list of entries.
"""
def __init__(self, entries: List):
def __init__(self, entries: List) -> None:
"""
Args:
entries: The list of dataset entries.
@ -45,7 +45,7 @@ class ListDataset(Dataset):
def __len__(
self,
):
) -> int:
return len(self._entries)
def __getitem__(self, index):

View File

@ -22,7 +22,7 @@ class AverageMeter:
Tracks the exact history of the added values in every epoch.
"""
def __init__(self):
def __init__(self) -> None:
"""
Initialize the structure with empty history and zero-ed moving average.
"""
@ -110,7 +110,7 @@ class Stats:
verbose: bool = False,
epoch: int = -1,
plot_file: Optional[str] = None,
):
) -> None:
"""
Args:
log_vars: The list of variable names to be logged.

View File

@ -64,7 +64,7 @@ class R2N2(ShapeNetBase): # pragma: no cover
voxels_rel_path: str = "ShapeNetVoxels",
load_textures: bool = True,
texture_resolution: int = 4,
):
) -> None:
"""
Store each object's synset id and models id the given directories.

View File

@ -437,7 +437,7 @@ class BlenderCamera(CamerasBase): # pragma: no cover
(which uses Blender for rendering the views for each model).
"""
def __init__(self, R=r, T=t, K=k, device: Device = "cpu"):
def __init__(self, R=r, T=t, K=k, device: Device = "cpu") -> None:
"""
Args:
R: Rotation matrix of shape (N, 3, 3).

View File

@ -31,7 +31,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
version: int = 1,
load_textures: bool = True,
texture_resolution: int = 4,
):
) -> None:
"""
Store each object's synset id and models id from data_dir.

View File

@ -30,7 +30,7 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
and __getitem__ need to be implemented.
"""
def __init__(self):
def __init__(self) -> None:
"""
Set up lists of synset_ids and model_ids.
"""
@ -44,7 +44,7 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
self.load_textures = True
self.texture_resolution = 4
def __len__(self):
def __len__(self) -> int:
"""
Return number of total models in the loaded dataset.
"""

View File

@ -187,7 +187,7 @@ def _make_node_transform(node: Dict[str, Any]) -> Transform3d:
class _GLTFLoader:
def __init__(self, stream: BinaryIO):
def __init__(self, stream: BinaryIO) -> None:
self._json_data = None
# Map from buffer index to (decoded) binary data
self._binary_data = {}
@ -539,7 +539,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
used which does not match the semantics of the standard.
"""
def __init__(self):
def __init__(self) -> None:
self.known_suffixes = (".glb",)
def read(

View File

@ -291,7 +291,7 @@ def load_objs_as_meshes(
class MeshObjFormat(MeshFormatInterpreter):
def __init__(self):
def __init__(self) -> None:
self.known_suffixes = (".obj",)
def read(

View File

@ -419,7 +419,7 @@ class MeshOffFormat(MeshFormatInterpreter):
"""
def __init__(self):
def __init__(self) -> None:
self.known_suffixes = (".off",)
def read(

View File

@ -63,7 +63,7 @@ class IO:
self,
include_default_formats: bool = True,
path_manager: Optional[PathManager] = None,
):
) -> None:
if path_manager is None:
self.path_manager = PathManager()
else:

View File

@ -66,7 +66,7 @@ class _PlyElementType:
self.name: (str) name of the element
"""
def __init__(self, name: str, count: int):
def __init__(self, name: str, count: int) -> None:
self.name = name
self.count = count
self.properties: List[_Property] = []
@ -130,7 +130,7 @@ class _PlyElementType:
class _PlyHeader:
def __init__(self, f):
def __init__(self, f) -> None:
"""
Load a header of a Ply file from a file-like object.
Members:
@ -1232,7 +1232,7 @@ def save_ply(
class MeshPlyFormat(MeshFormatInterpreter):
def __init__(self):
def __init__(self) -> None:
self.known_suffixes = (".ply",)
def read(
@ -1313,7 +1313,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
class PointcloudPlyFormat(PointcloudFormatInterpreter):
def __init__(self):
def __init__(self) -> None:
self.known_suffixes = (".ply",)
def read(

View File

@ -21,7 +21,7 @@ class GraphConv(nn.Module):
output_dim: int,
init: str = "normal",
directed: bool = False,
):
) -> None:
"""
Args:
input_dim: Number of input features per vertex.

View File

@ -15,7 +15,7 @@ EPS = 0.00001
class Cube:
def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1):
def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1) -> None:
"""
Initializes a cube given the bottom front left vertex coordinate
and the cube spacing

View File

@ -26,7 +26,7 @@ class SubdivideMeshes(nn.Module):
but different vertex positions.
"""
def __init__(self, meshes=None):
def __init__(self, meshes=None) -> None:
"""
Args:
meshes: Meshes object or None. If a meshes object is provided,

View File

@ -364,7 +364,7 @@ class FoVPerspectiveCameras(CamerasBase):
T=_T,
K=None,
device: Device = "cpu",
):
) -> None:
"""
Args:
@ -848,7 +848,7 @@ class PerspectiveCameras(CamerasBase):
K=None,
device="cpu",
image_size=((-1, -1),),
):
) -> None:
"""
Args:
@ -1013,7 +1013,7 @@ class OrthographicCameras(CamerasBase):
K=None,
device="cpu",
image_size=((-1, -1),),
):
) -> None:
"""
Args:

View File

@ -49,7 +49,7 @@ class EmissionAbsorptionRaymarcher(torch.nn.Module):
elements along the ray direction.
"""
def __init__(self, surface_thickness: int = 1):
def __init__(self, surface_thickness: int = 1) -> None:
"""
Args:
surface_thickness: Denotes the overlap between the absorption
@ -128,7 +128,7 @@ class AbsorptionOnlyRaymarcher(torch.nn.Module):
It then returns `opacities = 1 - total_transmission`.
"""
def __init__(self):
def __init__(self) -> None:
super().__init__()
def forward(

View File

@ -66,7 +66,7 @@ class GridRaysampler(torch.nn.Module):
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
):
) -> None:
"""
Args:
min_x: The leftmost x-coordinate of each ray's source pixel's center.
@ -150,7 +150,7 @@ class NDCGridRaysampler(GridRaysampler):
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
):
) -> None:
"""
Args:
image_width: The horizontal size of the image grid.
@ -192,7 +192,7 @@ class MonteCarloRaysampler(torch.nn.Module):
n_pts_per_ray: int,
min_depth: float,
max_depth: float,
):
) -> None:
"""
Args:
min_x: The smallest x-coordinate of each ray's source pixel.

View File

@ -105,7 +105,7 @@ class ImplicitRenderer(torch.nn.Module):
```
"""
def __init__(self, raysampler: Callable, raymarcher: Callable):
def __init__(self, raysampler: Callable, raymarcher: Callable) -> None:
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
@ -206,7 +206,7 @@ class VolumeRenderer(torch.nn.Module):
def __init__(
self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
):
) -> None:
"""
Args:
raysampler: A `Callable` that takes as input scene cameras
@ -256,7 +256,7 @@ class VolumeSampler(torch.nn.Module):
at 3D points sampled along projection rays.
"""
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"):
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear") -> None:
"""
Args:
volumes: An instance of the `Volumes` class representing a

View File

@ -164,7 +164,7 @@ class DirectionalLights(TensorProperties):
specular_color=((0.2, 0.2, 0.2),),
direction=((0, 1, 0),),
device: Device = "cpu",
):
) -> None:
"""
Args:
ambient_color: RGB color of the ambient component.
@ -225,7 +225,7 @@ class PointLights(TensorProperties):
specular_color=((0.2, 0.2, 0.2),),
location=((0, 1, 0),),
device: Device = "cpu",
):
) -> None:
"""
Args:
ambient_color: RGB color of the ambient component
@ -294,7 +294,7 @@ class AmbientLights(TensorProperties):
not used in rendering.
"""
def __init__(self, *, ambient_color=None, device: Device = "cpu"):
def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
"""
If ambient_color is provided, it should be a sequence of
triples of floats.

View File

@ -24,7 +24,7 @@ class Materials(TensorProperties):
specular_color=((1, 1, 1),),
shininess=64,
device: Device = "cpu",
):
) -> None:
"""
Args:
ambient_color: RGB ambient reflectivity of the material

View File

@ -84,7 +84,7 @@ class ClippedFaces:
barycentric_conversion: Optional[torch.Tensor] = None,
faces_clipped_to_conversion_idx: Optional[torch.Tensor] = None,
clipped_faces_neighbor_idx: Optional[torch.Tensor] = None,
):
) -> None:
self.face_verts = face_verts
self.mesh_to_face_first_idx = mesh_to_face_first_idx
self.num_faces_per_mesh = num_faces_per_mesh
@ -139,7 +139,7 @@ class ClipFrustum:
perspective_correct: bool = False,
cull: bool = True,
z_clip_value: Optional[float] = None,
):
) -> None:
self.left = left
self.right = right
self.top = top

View File

@ -49,7 +49,7 @@ class RasterizationSettings:
cull_backfaces: bool = False,
z_clip_value: Optional[float] = None,
cull_to_frustum: bool = False,
):
) -> None:
self.image_size = image_size
self.blur_radius = blur_radius
self.faces_per_pixel = faces_per_pixel
@ -68,7 +68,7 @@ class MeshRasterizer(nn.Module):
Meshes.
"""
def __init__(self, cameras=None, raster_settings=None):
def __init__(self, cameras=None, raster_settings=None) -> None:
"""
Args:
cameras: A cameras object which has a `transform_points` method

View File

@ -32,7 +32,7 @@ class MeshRenderer(nn.Module):
function.
"""
def __init__(self, rasterizer, shader):
def __init__(self, rasterizer, shader) -> None:
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
@ -76,7 +76,7 @@ class MeshRendererWithFragments(nn.Module):
depth = fragments.zbuf
"""
def __init__(self, rasterizer, shader):
def __init__(self, rasterizer, shader) -> None:
super().__init__()
self.rasterizer = rasterizer
self.shader = shader

View File

@ -51,7 +51,7 @@ class HardPhongShader(nn.Module):
lights=None,
materials=None,
blend_params=None,
):
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
@ -112,7 +112,7 @@ class SoftPhongShader(nn.Module):
lights=None,
materials=None,
blend_params=None,
):
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
@ -178,7 +178,7 @@ class HardGouraudShader(nn.Module):
lights=None,
materials=None,
blend_params=None,
):
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
@ -243,7 +243,7 @@ class SoftGouraudShader(nn.Module):
lights=None,
materials=None,
blend_params=None,
):
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
@ -325,7 +325,7 @@ class HardFlatShader(nn.Module):
lights=None,
materials=None,
blend_params=None,
):
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
@ -381,7 +381,7 @@ class SoftSilhouetteShader(nn.Module):
3D Reasoning', ICCV 2019
"""
def __init__(self, blend_params=None):
def __init__(self, blend_params=None) -> None:
super().__init__()
self.blend_params = blend_params if blend_params is not None else BlendParams()

View File

@ -262,7 +262,7 @@ class TexturesBase:
"""
raise NotImplementedError()
def __getitem__(self, index):
def __getitem__(self, index) -> "TexturesBase":
"""
Each texture class should implement a method
to get the texture properties for the
@ -321,7 +321,7 @@ def Textures(
class TexturesAtlas(TexturesBase):
def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]):
def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]) -> None:
"""
A texture representation where each face has a square texture map.
This is based on the implementation from SoftRasterizer [1].
@ -420,7 +420,7 @@ class TexturesAtlas(TexturesBase):
tex._num_faces_per_mesh = num_faces
return tex
def __getitem__(self, index):
def __getitem__(self, index) -> "TexturesAtlas":
props = ["atlas_list", "_num_faces_per_mesh"]
new_props = self._getitem(index, props=props)
atlas = new_props["atlas_list"]
@ -596,7 +596,7 @@ class TexturesUV(TexturesBase):
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
padding_mode: str = "border",
align_corners: bool = True,
):
) -> None:
"""
Textures are represented as a per mesh texture map and uv coordinates for each
vertex in each face. NOTE: this class only supports one texture map per mesh.
@ -786,7 +786,7 @@ class TexturesUV(TexturesBase):
tex.valid = self.valid.detach()
return tex
def __getitem__(self, index):
def __getitem__(self, index) -> "TexturesUV":
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
new_props = self._getitem(index, props)
faces_uvs = new_props["faces_uvs_list"]
@ -1257,7 +1257,7 @@ class TexturesVertex(TexturesBase):
def __init__(
self,
verts_features: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
):
) -> None:
"""
Batched texture representation where each vertex in a mesh
has a C dimensional feature vector.

View File

@ -25,7 +25,7 @@ class AlphaCompositor(nn.Module):
def __init__(
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
):
) -> None:
super().__init__()
self.background_color = background_color
@ -47,7 +47,7 @@ class NormWeightedCompositor(nn.Module):
def __init__(
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
):
) -> None:
super().__init__()
self.background_color = background_color

View File

@ -326,7 +326,7 @@ class Renderer(torch.nn.Module):
background_normalized_depth: float = _C.EPS,
n_channels: int = 3,
n_track: int = 5,
):
) -> None:
super(Renderer, self).__init__()
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
self._renderer = _C.PulsarRenderer(

View File

@ -51,7 +51,7 @@ class PulsarPointsRenderer(nn.Module):
n_channels: int = 3,
max_num_spheres: int = int(1e6), # noqa: B008
**kwargs,
):
) -> None:
"""
rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
compositor (ignored): Only keeping this for interface consistency. Default: None.

View File

@ -38,7 +38,7 @@ class PointsRasterizationSettings:
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None,
):
) -> None:
self.image_size = image_size
self.radius = radius
self.points_per_pixel = points_per_pixel
@ -51,7 +51,7 @@ class PointsRasterizer(nn.Module):
This class implements methods for rasterizing a batch of pointclouds.
"""
def __init__(self, cameras=None, raster_settings=None):
def __init__(self, cameras=None, raster_settings=None) -> None:
"""
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the

View File

@ -32,7 +32,7 @@ class PointsRenderer(nn.Module):
function.
"""
def __init__(self, rasterizer, compositor):
def __init__(self, rasterizer, compositor) -> None:
super().__init__()
self.rasterizer = rasterizer
self.compositor = compositor

View File

@ -25,7 +25,7 @@ class TensorAccessor(nn.Module):
and one element in the batch needs to be modified.
"""
def __init__(self, class_object, index: Union[int, slice]):
def __init__(self, class_object, index: Union[int, slice]) -> None:
"""
Args:
class_object: this should be an instance of a class which has
@ -96,7 +96,7 @@ class TensorProperties(nn.Module):
def __init__(
self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
):
) -> None:
"""
Args:
dtype: data type to set for the inputs
@ -143,7 +143,7 @@ class TensorProperties(nn.Module):
def isempty(self) -> bool:
return self._N == 0
def __getitem__(self, index: Union[int, slice]):
def __getitem__(self, index: Union[int, slice]) -> TensorAccessor:
"""
Args:

View File

@ -219,7 +219,7 @@ class Meshes:
textures=None,
*,
verts_normals=None,
):
) -> None:
"""
Args:
verts:
@ -469,10 +469,10 @@ class Meshes:
else:
raise ValueError("verts_normals must be a list or tensor")
def __len__(self):
def __len__(self) -> int:
return self._N
def __getitem__(self, index):
def __getitem__(self, index) -> "Meshes":
"""
Args:
index: Specifying the index of the mesh to retrieve.
@ -493,7 +493,7 @@ class Meshes:
# NOTE consider converting index to cpu for efficiency
if index.dtype == torch.bool:
# advanced indexing on a single dimension
index = index.nonzero()
index = index.nonzero() # pyre-ignore
index = index.squeeze(1) if index.numel() > 0 else index
index = index.tolist()
verts = [self.verts_list()[i] for i in index]

View File

@ -108,7 +108,7 @@ class Pointclouds:
"equisized",
]
def __init__(self, points, normals=None, features=None):
def __init__(self, points, normals=None, features=None) -> None:
"""
Args:
points:
@ -306,10 +306,10 @@ class Pointclouds:
points in a cloud."
)
def __len__(self):
def __len__(self) -> int:
return self._N
def __getitem__(self, index):
def __getitem__(self, index) -> "Pointclouds":
"""
Args:
index: Specifying the index of the cloud to retrieve.
@ -343,7 +343,7 @@ class Pointclouds:
# NOTE consider converting index to cpu for efficiency
if index.dtype == torch.bool:
# advanced indexing on a single dimension
index = index.nonzero()
index = index.nonzero() # pyre-ignore
index = index.squeeze(1) if index.numel() > 0 else index
index = index.tolist()
points = [self.points_list()[i] for i in index]

View File

@ -155,7 +155,7 @@ class Volumes:
features: Optional[_TensorBatch] = None,
voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0),
):
) -> None:
"""
Args:
**densities**: Batch of input feature volume occupancies of shape

View File

@ -144,7 +144,7 @@ class Transform3d:
dtype: torch.dtype = torch.float32,
device: Device = "cpu",
matrix: Optional[torch.Tensor] = None,
):
) -> None:
"""
Args:
dtype: The data type of the transformation matrix.
@ -176,7 +176,7 @@ class Transform3d:
self.device = make_device(device)
self.dtype = dtype
def __len__(self):
def __len__(self) -> int:
return self.get_matrix().shape[0]
def __getitem__(
@ -462,7 +462,7 @@ class Translate(Transform3d):
z=None,
dtype: torch.dtype = torch.float32,
device: Optional[Device] = None,
):
) -> None:
"""
Create a new Transform3d representing 3D translations.
@ -503,7 +503,7 @@ class Scale(Transform3d):
z=None,
dtype: torch.dtype = torch.float32,
device: Optional[Device] = None,
):
) -> None:
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
@ -549,7 +549,7 @@ class Rotate(Transform3d):
dtype: torch.dtype = torch.float32,
device: Optional[Device] = None,
orthogonal_tol: float = 1e-5,
):
) -> None:
"""
Create a new Transform3d representing 3D rotation using a rotation
matrix as the input.
@ -589,7 +589,7 @@ class RotateAxisAngle(Rotate):
degrees: bool = True,
dtype: torch.dtype = torch.float64,
device: Optional[Device] = None,
):
) -> None:
"""
Create a new Transform3d representing 3D rotation about an axis
by an angle.