Annotate dunder functions

Summary: Annotate the (return type of the) following dunder functions across the codebase: `__init__()`, `__len__()`, `__getitem__()`

Reviewed By: nikhilaravi

Differential Revision: D29001801

fbshipit-source-id: 928d9e1c417ffe01ab8c0445311287786e997c7c
This commit is contained in:
Patrick Labatut 2021-06-24 15:18:19 -07:00 committed by Facebook GitHub Bot
parent 35855bf860
commit 64289a491d
35 changed files with 79 additions and 79 deletions

View File

@ -36,7 +36,7 @@ class ListDataset(Dataset):
A simple dataset made of a list of entries. A simple dataset made of a list of entries.
""" """
def __init__(self, entries: List): def __init__(self, entries: List) -> None:
""" """
Args: Args:
entries: The list of dataset entries. entries: The list of dataset entries.
@ -45,7 +45,7 @@ class ListDataset(Dataset):
def __len__( def __len__(
self, self,
): ) -> int:
return len(self._entries) return len(self._entries)
def __getitem__(self, index): def __getitem__(self, index):

View File

@ -22,7 +22,7 @@ class AverageMeter:
Tracks the exact history of the added values in every epoch. Tracks the exact history of the added values in every epoch.
""" """
def __init__(self): def __init__(self) -> None:
""" """
Initialize the structure with empty history and zero-ed moving average. Initialize the structure with empty history and zero-ed moving average.
""" """
@ -110,7 +110,7 @@ class Stats:
verbose: bool = False, verbose: bool = False,
epoch: int = -1, epoch: int = -1,
plot_file: Optional[str] = None, plot_file: Optional[str] = None,
): ) -> None:
""" """
Args: Args:
log_vars: The list of variable names to be logged. log_vars: The list of variable names to be logged.

View File

@ -64,7 +64,7 @@ class R2N2(ShapeNetBase): # pragma: no cover
voxels_rel_path: str = "ShapeNetVoxels", voxels_rel_path: str = "ShapeNetVoxels",
load_textures: bool = True, load_textures: bool = True,
texture_resolution: int = 4, texture_resolution: int = 4,
): ) -> None:
""" """
Store each object's synset id and models id the given directories. Store each object's synset id and models id the given directories.

View File

@ -437,7 +437,7 @@ class BlenderCamera(CamerasBase): # pragma: no cover
(which uses Blender for rendering the views for each model). (which uses Blender for rendering the views for each model).
""" """
def __init__(self, R=r, T=t, K=k, device: Device = "cpu"): def __init__(self, R=r, T=t, K=k, device: Device = "cpu") -> None:
""" """
Args: Args:
R: Rotation matrix of shape (N, 3, 3). R: Rotation matrix of shape (N, 3, 3).

View File

@ -31,7 +31,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
version: int = 1, version: int = 1,
load_textures: bool = True, load_textures: bool = True,
texture_resolution: int = 4, texture_resolution: int = 4,
): ) -> None:
""" """
Store each object's synset id and models id from data_dir. Store each object's synset id and models id from data_dir.

View File

@ -30,7 +30,7 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
and __getitem__ need to be implemented. and __getitem__ need to be implemented.
""" """
def __init__(self): def __init__(self) -> None:
""" """
Set up lists of synset_ids and model_ids. Set up lists of synset_ids and model_ids.
""" """
@ -44,7 +44,7 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
self.load_textures = True self.load_textures = True
self.texture_resolution = 4 self.texture_resolution = 4
def __len__(self): def __len__(self) -> int:
""" """
Return number of total models in the loaded dataset. Return number of total models in the loaded dataset.
""" """

View File

@ -187,7 +187,7 @@ def _make_node_transform(node: Dict[str, Any]) -> Transform3d:
class _GLTFLoader: class _GLTFLoader:
def __init__(self, stream: BinaryIO): def __init__(self, stream: BinaryIO) -> None:
self._json_data = None self._json_data = None
# Map from buffer index to (decoded) binary data # Map from buffer index to (decoded) binary data
self._binary_data = {} self._binary_data = {}
@ -539,7 +539,7 @@ class MeshGlbFormat(MeshFormatInterpreter):
used which does not match the semantics of the standard. used which does not match the semantics of the standard.
""" """
def __init__(self): def __init__(self) -> None:
self.known_suffixes = (".glb",) self.known_suffixes = (".glb",)
def read( def read(

View File

@ -291,7 +291,7 @@ def load_objs_as_meshes(
class MeshObjFormat(MeshFormatInterpreter): class MeshObjFormat(MeshFormatInterpreter):
def __init__(self): def __init__(self) -> None:
self.known_suffixes = (".obj",) self.known_suffixes = (".obj",)
def read( def read(

View File

@ -419,7 +419,7 @@ class MeshOffFormat(MeshFormatInterpreter):
""" """
def __init__(self): def __init__(self) -> None:
self.known_suffixes = (".off",) self.known_suffixes = (".off",)
def read( def read(

View File

@ -63,7 +63,7 @@ class IO:
self, self,
include_default_formats: bool = True, include_default_formats: bool = True,
path_manager: Optional[PathManager] = None, path_manager: Optional[PathManager] = None,
): ) -> None:
if path_manager is None: if path_manager is None:
self.path_manager = PathManager() self.path_manager = PathManager()
else: else:

View File

@ -66,7 +66,7 @@ class _PlyElementType:
self.name: (str) name of the element self.name: (str) name of the element
""" """
def __init__(self, name: str, count: int): def __init__(self, name: str, count: int) -> None:
self.name = name self.name = name
self.count = count self.count = count
self.properties: List[_Property] = [] self.properties: List[_Property] = []
@ -130,7 +130,7 @@ class _PlyElementType:
class _PlyHeader: class _PlyHeader:
def __init__(self, f): def __init__(self, f) -> None:
""" """
Load a header of a Ply file from a file-like object. Load a header of a Ply file from a file-like object.
Members: Members:
@ -1232,7 +1232,7 @@ def save_ply(
class MeshPlyFormat(MeshFormatInterpreter): class MeshPlyFormat(MeshFormatInterpreter):
def __init__(self): def __init__(self) -> None:
self.known_suffixes = (".ply",) self.known_suffixes = (".ply",)
def read( def read(
@ -1313,7 +1313,7 @@ class MeshPlyFormat(MeshFormatInterpreter):
class PointcloudPlyFormat(PointcloudFormatInterpreter): class PointcloudPlyFormat(PointcloudFormatInterpreter):
def __init__(self): def __init__(self) -> None:
self.known_suffixes = (".ply",) self.known_suffixes = (".ply",)
def read( def read(

View File

@ -21,7 +21,7 @@ class GraphConv(nn.Module):
output_dim: int, output_dim: int,
init: str = "normal", init: str = "normal",
directed: bool = False, directed: bool = False,
): ) -> None:
""" """
Args: Args:
input_dim: Number of input features per vertex. input_dim: Number of input features per vertex.

View File

@ -15,7 +15,7 @@ EPS = 0.00001
class Cube: class Cube:
def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1): def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1) -> None:
""" """
Initializes a cube given the bottom front left vertex coordinate Initializes a cube given the bottom front left vertex coordinate
and the cube spacing and the cube spacing

View File

@ -26,7 +26,7 @@ class SubdivideMeshes(nn.Module):
but different vertex positions. but different vertex positions.
""" """
def __init__(self, meshes=None): def __init__(self, meshes=None) -> None:
""" """
Args: Args:
meshes: Meshes object or None. If a meshes object is provided, meshes: Meshes object or None. If a meshes object is provided,

View File

@ -364,7 +364,7 @@ class FoVPerspectiveCameras(CamerasBase):
T=_T, T=_T,
K=None, K=None,
device: Device = "cpu", device: Device = "cpu",
): ) -> None:
""" """
Args: Args:
@ -848,7 +848,7 @@ class PerspectiveCameras(CamerasBase):
K=None, K=None,
device="cpu", device="cpu",
image_size=((-1, -1),), image_size=((-1, -1),),
): ) -> None:
""" """
Args: Args:
@ -1013,7 +1013,7 @@ class OrthographicCameras(CamerasBase):
K=None, K=None,
device="cpu", device="cpu",
image_size=((-1, -1),), image_size=((-1, -1),),
): ) -> None:
""" """
Args: Args:

View File

@ -49,7 +49,7 @@ class EmissionAbsorptionRaymarcher(torch.nn.Module):
elements along the ray direction. elements along the ray direction.
""" """
def __init__(self, surface_thickness: int = 1): def __init__(self, surface_thickness: int = 1) -> None:
""" """
Args: Args:
surface_thickness: Denotes the overlap between the absorption surface_thickness: Denotes the overlap between the absorption
@ -128,7 +128,7 @@ class AbsorptionOnlyRaymarcher(torch.nn.Module):
It then returns `opacities = 1 - total_transmission`. It then returns `opacities = 1 - total_transmission`.
""" """
def __init__(self): def __init__(self) -> None:
super().__init__() super().__init__()
def forward( def forward(

View File

@ -66,7 +66,7 @@ class GridRaysampler(torch.nn.Module):
n_pts_per_ray: int, n_pts_per_ray: int,
min_depth: float, min_depth: float,
max_depth: float, max_depth: float,
): ) -> None:
""" """
Args: Args:
min_x: The leftmost x-coordinate of each ray's source pixel's center. min_x: The leftmost x-coordinate of each ray's source pixel's center.
@ -150,7 +150,7 @@ class NDCGridRaysampler(GridRaysampler):
n_pts_per_ray: int, n_pts_per_ray: int,
min_depth: float, min_depth: float,
max_depth: float, max_depth: float,
): ) -> None:
""" """
Args: Args:
image_width: The horizontal size of the image grid. image_width: The horizontal size of the image grid.
@ -192,7 +192,7 @@ class MonteCarloRaysampler(torch.nn.Module):
n_pts_per_ray: int, n_pts_per_ray: int,
min_depth: float, min_depth: float,
max_depth: float, max_depth: float,
): ) -> None:
""" """
Args: Args:
min_x: The smallest x-coordinate of each ray's source pixel. min_x: The smallest x-coordinate of each ray's source pixel.

View File

@ -105,7 +105,7 @@ class ImplicitRenderer(torch.nn.Module):
``` ```
""" """
def __init__(self, raysampler: Callable, raymarcher: Callable): def __init__(self, raysampler: Callable, raymarcher: Callable) -> None:
""" """
Args: Args:
raysampler: A `Callable` that takes as input scene cameras raysampler: A `Callable` that takes as input scene cameras
@ -206,7 +206,7 @@ class VolumeRenderer(torch.nn.Module):
def __init__( def __init__(
self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear" self, raysampler: Callable, raymarcher: Callable, sample_mode: str = "bilinear"
): ) -> None:
""" """
Args: Args:
raysampler: A `Callable` that takes as input scene cameras raysampler: A `Callable` that takes as input scene cameras
@ -256,7 +256,7 @@ class VolumeSampler(torch.nn.Module):
at 3D points sampled along projection rays. at 3D points sampled along projection rays.
""" """
def __init__(self, volumes: Volumes, sample_mode: str = "bilinear"): def __init__(self, volumes: Volumes, sample_mode: str = "bilinear") -> None:
""" """
Args: Args:
volumes: An instance of the `Volumes` class representing a volumes: An instance of the `Volumes` class representing a

View File

@ -164,7 +164,7 @@ class DirectionalLights(TensorProperties):
specular_color=((0.2, 0.2, 0.2),), specular_color=((0.2, 0.2, 0.2),),
direction=((0, 1, 0),), direction=((0, 1, 0),),
device: Device = "cpu", device: Device = "cpu",
): ) -> None:
""" """
Args: Args:
ambient_color: RGB color of the ambient component. ambient_color: RGB color of the ambient component.
@ -225,7 +225,7 @@ class PointLights(TensorProperties):
specular_color=((0.2, 0.2, 0.2),), specular_color=((0.2, 0.2, 0.2),),
location=((0, 1, 0),), location=((0, 1, 0),),
device: Device = "cpu", device: Device = "cpu",
): ) -> None:
""" """
Args: Args:
ambient_color: RGB color of the ambient component ambient_color: RGB color of the ambient component
@ -294,7 +294,7 @@ class AmbientLights(TensorProperties):
not used in rendering. not used in rendering.
""" """
def __init__(self, *, ambient_color=None, device: Device = "cpu"): def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
""" """
If ambient_color is provided, it should be a sequence of If ambient_color is provided, it should be a sequence of
triples of floats. triples of floats.

View File

@ -24,7 +24,7 @@ class Materials(TensorProperties):
specular_color=((1, 1, 1),), specular_color=((1, 1, 1),),
shininess=64, shininess=64,
device: Device = "cpu", device: Device = "cpu",
): ) -> None:
""" """
Args: Args:
ambient_color: RGB ambient reflectivity of the material ambient_color: RGB ambient reflectivity of the material

View File

@ -84,7 +84,7 @@ class ClippedFaces:
barycentric_conversion: Optional[torch.Tensor] = None, barycentric_conversion: Optional[torch.Tensor] = None,
faces_clipped_to_conversion_idx: Optional[torch.Tensor] = None, faces_clipped_to_conversion_idx: Optional[torch.Tensor] = None,
clipped_faces_neighbor_idx: Optional[torch.Tensor] = None, clipped_faces_neighbor_idx: Optional[torch.Tensor] = None,
): ) -> None:
self.face_verts = face_verts self.face_verts = face_verts
self.mesh_to_face_first_idx = mesh_to_face_first_idx self.mesh_to_face_first_idx = mesh_to_face_first_idx
self.num_faces_per_mesh = num_faces_per_mesh self.num_faces_per_mesh = num_faces_per_mesh
@ -139,7 +139,7 @@ class ClipFrustum:
perspective_correct: bool = False, perspective_correct: bool = False,
cull: bool = True, cull: bool = True,
z_clip_value: Optional[float] = None, z_clip_value: Optional[float] = None,
): ) -> None:
self.left = left self.left = left
self.right = right self.right = right
self.top = top self.top = top

View File

@ -49,7 +49,7 @@ class RasterizationSettings:
cull_backfaces: bool = False, cull_backfaces: bool = False,
z_clip_value: Optional[float] = None, z_clip_value: Optional[float] = None,
cull_to_frustum: bool = False, cull_to_frustum: bool = False,
): ) -> None:
self.image_size = image_size self.image_size = image_size
self.blur_radius = blur_radius self.blur_radius = blur_radius
self.faces_per_pixel = faces_per_pixel self.faces_per_pixel = faces_per_pixel
@ -68,7 +68,7 @@ class MeshRasterizer(nn.Module):
Meshes. Meshes.
""" """
def __init__(self, cameras=None, raster_settings=None): def __init__(self, cameras=None, raster_settings=None) -> None:
""" """
Args: Args:
cameras: A cameras object which has a `transform_points` method cameras: A cameras object which has a `transform_points` method

View File

@ -32,7 +32,7 @@ class MeshRenderer(nn.Module):
function. function.
""" """
def __init__(self, rasterizer, shader): def __init__(self, rasterizer, shader) -> None:
super().__init__() super().__init__()
self.rasterizer = rasterizer self.rasterizer = rasterizer
self.shader = shader self.shader = shader
@ -76,7 +76,7 @@ class MeshRendererWithFragments(nn.Module):
depth = fragments.zbuf depth = fragments.zbuf
""" """
def __init__(self, rasterizer, shader): def __init__(self, rasterizer, shader) -> None:
super().__init__() super().__init__()
self.rasterizer = rasterizer self.rasterizer = rasterizer
self.shader = shader self.shader = shader

View File

@ -51,7 +51,7 @@ class HardPhongShader(nn.Module):
lights=None, lights=None,
materials=None, materials=None,
blend_params=None, blend_params=None,
): ) -> None:
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
self.materials = ( self.materials = (
@ -112,7 +112,7 @@ class SoftPhongShader(nn.Module):
lights=None, lights=None,
materials=None, materials=None,
blend_params=None, blend_params=None,
): ) -> None:
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
self.materials = ( self.materials = (
@ -178,7 +178,7 @@ class HardGouraudShader(nn.Module):
lights=None, lights=None,
materials=None, materials=None,
blend_params=None, blend_params=None,
): ) -> None:
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
self.materials = ( self.materials = (
@ -243,7 +243,7 @@ class SoftGouraudShader(nn.Module):
lights=None, lights=None,
materials=None, materials=None,
blend_params=None, blend_params=None,
): ) -> None:
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
self.materials = ( self.materials = (
@ -325,7 +325,7 @@ class HardFlatShader(nn.Module):
lights=None, lights=None,
materials=None, materials=None,
blend_params=None, blend_params=None,
): ) -> None:
super().__init__() super().__init__()
self.lights = lights if lights is not None else PointLights(device=device) self.lights = lights if lights is not None else PointLights(device=device)
self.materials = ( self.materials = (
@ -381,7 +381,7 @@ class SoftSilhouetteShader(nn.Module):
3D Reasoning', ICCV 2019 3D Reasoning', ICCV 2019
""" """
def __init__(self, blend_params=None): def __init__(self, blend_params=None) -> None:
super().__init__() super().__init__()
self.blend_params = blend_params if blend_params is not None else BlendParams() self.blend_params = blend_params if blend_params is not None else BlendParams()

View File

@ -262,7 +262,7 @@ class TexturesBase:
""" """
raise NotImplementedError() raise NotImplementedError()
def __getitem__(self, index): def __getitem__(self, index) -> "TexturesBase":
""" """
Each texture class should implement a method Each texture class should implement a method
to get the texture properties for the to get the texture properties for the
@ -321,7 +321,7 @@ def Textures(
class TexturesAtlas(TexturesBase): class TexturesAtlas(TexturesBase):
def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]): def __init__(self, atlas: Union[torch.Tensor, List[torch.Tensor]]) -> None:
""" """
A texture representation where each face has a square texture map. A texture representation where each face has a square texture map.
This is based on the implementation from SoftRasterizer [1]. This is based on the implementation from SoftRasterizer [1].
@ -420,7 +420,7 @@ class TexturesAtlas(TexturesBase):
tex._num_faces_per_mesh = num_faces tex._num_faces_per_mesh = num_faces
return tex return tex
def __getitem__(self, index): def __getitem__(self, index) -> "TexturesAtlas":
props = ["atlas_list", "_num_faces_per_mesh"] props = ["atlas_list", "_num_faces_per_mesh"]
new_props = self._getitem(index, props=props) new_props = self._getitem(index, props=props)
atlas = new_props["atlas_list"] atlas = new_props["atlas_list"]
@ -596,7 +596,7 @@ class TexturesUV(TexturesBase):
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
padding_mode: str = "border", padding_mode: str = "border",
align_corners: bool = True, align_corners: bool = True,
): ) -> None:
""" """
Textures are represented as a per mesh texture map and uv coordinates for each Textures are represented as a per mesh texture map and uv coordinates for each
vertex in each face. NOTE: this class only supports one texture map per mesh. vertex in each face. NOTE: this class only supports one texture map per mesh.
@ -786,7 +786,7 @@ class TexturesUV(TexturesBase):
tex.valid = self.valid.detach() tex.valid = self.valid.detach()
return tex return tex
def __getitem__(self, index): def __getitem__(self, index) -> "TexturesUV":
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"] props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
new_props = self._getitem(index, props) new_props = self._getitem(index, props)
faces_uvs = new_props["faces_uvs_list"] faces_uvs = new_props["faces_uvs_list"]
@ -1257,7 +1257,7 @@ class TexturesVertex(TexturesBase):
def __init__( def __init__(
self, self,
verts_features: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]], verts_features: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
): ) -> None:
""" """
Batched texture representation where each vertex in a mesh Batched texture representation where each vertex in a mesh
has a C dimensional feature vector. has a C dimensional feature vector.

View File

@ -25,7 +25,7 @@ class AlphaCompositor(nn.Module):
def __init__( def __init__(
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
): ) -> None:
super().__init__() super().__init__()
self.background_color = background_color self.background_color = background_color
@ -47,7 +47,7 @@ class NormWeightedCompositor(nn.Module):
def __init__( def __init__(
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
): ) -> None:
super().__init__() super().__init__()
self.background_color = background_color self.background_color = background_color

View File

@ -326,7 +326,7 @@ class Renderer(torch.nn.Module):
background_normalized_depth: float = _C.EPS, background_normalized_depth: float = _C.EPS,
n_channels: int = 3, n_channels: int = 3,
n_track: int = 5, n_track: int = 5,
): ) -> None:
super(Renderer, self).__init__() super(Renderer, self).__init__()
# pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`. # pyre-fixme[16]: Module `pytorch3d` has no attribute `_C`.
self._renderer = _C.PulsarRenderer( self._renderer = _C.PulsarRenderer(

View File

@ -51,7 +51,7 @@ class PulsarPointsRenderer(nn.Module):
n_channels: int = 3, n_channels: int = 3,
max_num_spheres: int = int(1e6), # noqa: B008 max_num_spheres: int = int(1e6), # noqa: B008
**kwargs, **kwargs,
): ) -> None:
""" """
rasterizer (PointsRasterizer): An object encapsulating rasterization parameters. rasterizer (PointsRasterizer): An object encapsulating rasterization parameters.
compositor (ignored): Only keeping this for interface consistency. Default: None. compositor (ignored): Only keeping this for interface consistency. Default: None.

View File

@ -38,7 +38,7 @@ class PointsRasterizationSettings:
points_per_pixel: int = 8, points_per_pixel: int = 8,
bin_size: Optional[int] = None, bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None, max_points_per_bin: Optional[int] = None,
): ) -> None:
self.image_size = image_size self.image_size = image_size
self.radius = radius self.radius = radius
self.points_per_pixel = points_per_pixel self.points_per_pixel = points_per_pixel
@ -51,7 +51,7 @@ class PointsRasterizer(nn.Module):
This class implements methods for rasterizing a batch of pointclouds. This class implements methods for rasterizing a batch of pointclouds.
""" """
def __init__(self, cameras=None, raster_settings=None): def __init__(self, cameras=None, raster_settings=None) -> None:
""" """
cameras: A cameras object which has a `transform_points` method cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the which returns the transformed points after applying the

View File

@ -32,7 +32,7 @@ class PointsRenderer(nn.Module):
function. function.
""" """
def __init__(self, rasterizer, compositor): def __init__(self, rasterizer, compositor) -> None:
super().__init__() super().__init__()
self.rasterizer = rasterizer self.rasterizer = rasterizer
self.compositor = compositor self.compositor = compositor

View File

@ -25,7 +25,7 @@ class TensorAccessor(nn.Module):
and one element in the batch needs to be modified. and one element in the batch needs to be modified.
""" """
def __init__(self, class_object, index: Union[int, slice]): def __init__(self, class_object, index: Union[int, slice]) -> None:
""" """
Args: Args:
class_object: this should be an instance of a class which has class_object: this should be an instance of a class which has
@ -96,7 +96,7 @@ class TensorProperties(nn.Module):
def __init__( def __init__(
self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs self, dtype: torch.dtype = torch.float32, device: Device = "cpu", **kwargs
): ) -> None:
""" """
Args: Args:
dtype: data type to set for the inputs dtype: data type to set for the inputs
@ -143,7 +143,7 @@ class TensorProperties(nn.Module):
def isempty(self) -> bool: def isempty(self) -> bool:
return self._N == 0 return self._N == 0
def __getitem__(self, index: Union[int, slice]): def __getitem__(self, index: Union[int, slice]) -> TensorAccessor:
""" """
Args: Args:

View File

@ -219,7 +219,7 @@ class Meshes:
textures=None, textures=None,
*, *,
verts_normals=None, verts_normals=None,
): ) -> None:
""" """
Args: Args:
verts: verts:
@ -469,10 +469,10 @@ class Meshes:
else: else:
raise ValueError("verts_normals must be a list or tensor") raise ValueError("verts_normals must be a list or tensor")
def __len__(self): def __len__(self) -> int:
return self._N return self._N
def __getitem__(self, index): def __getitem__(self, index) -> "Meshes":
""" """
Args: Args:
index: Specifying the index of the mesh to retrieve. index: Specifying the index of the mesh to retrieve.
@ -493,7 +493,7 @@ class Meshes:
# NOTE consider converting index to cpu for efficiency # NOTE consider converting index to cpu for efficiency
if index.dtype == torch.bool: if index.dtype == torch.bool:
# advanced indexing on a single dimension # advanced indexing on a single dimension
index = index.nonzero() index = index.nonzero() # pyre-ignore
index = index.squeeze(1) if index.numel() > 0 else index index = index.squeeze(1) if index.numel() > 0 else index
index = index.tolist() index = index.tolist()
verts = [self.verts_list()[i] for i in index] verts = [self.verts_list()[i] for i in index]

View File

@ -108,7 +108,7 @@ class Pointclouds:
"equisized", "equisized",
] ]
def __init__(self, points, normals=None, features=None): def __init__(self, points, normals=None, features=None) -> None:
""" """
Args: Args:
points: points:
@ -306,10 +306,10 @@ class Pointclouds:
points in a cloud." points in a cloud."
) )
def __len__(self): def __len__(self) -> int:
return self._N return self._N
def __getitem__(self, index): def __getitem__(self, index) -> "Pointclouds":
""" """
Args: Args:
index: Specifying the index of the cloud to retrieve. index: Specifying the index of the cloud to retrieve.
@ -343,7 +343,7 @@ class Pointclouds:
# NOTE consider converting index to cpu for efficiency # NOTE consider converting index to cpu for efficiency
if index.dtype == torch.bool: if index.dtype == torch.bool:
# advanced indexing on a single dimension # advanced indexing on a single dimension
index = index.nonzero() index = index.nonzero() # pyre-ignore
index = index.squeeze(1) if index.numel() > 0 else index index = index.squeeze(1) if index.numel() > 0 else index
index = index.tolist() index = index.tolist()
points = [self.points_list()[i] for i in index] points = [self.points_list()[i] for i in index]

View File

@ -155,7 +155,7 @@ class Volumes:
features: Optional[_TensorBatch] = None, features: Optional[_TensorBatch] = None,
voxel_size: _VoxelSize = 1.0, voxel_size: _VoxelSize = 1.0,
volume_translation: _Translation = (0.0, 0.0, 0.0), volume_translation: _Translation = (0.0, 0.0, 0.0),
): ) -> None:
""" """
Args: Args:
**densities**: Batch of input feature volume occupancies of shape **densities**: Batch of input feature volume occupancies of shape

View File

@ -144,7 +144,7 @@ class Transform3d:
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: Device = "cpu", device: Device = "cpu",
matrix: Optional[torch.Tensor] = None, matrix: Optional[torch.Tensor] = None,
): ) -> None:
""" """
Args: Args:
dtype: The data type of the transformation matrix. dtype: The data type of the transformation matrix.
@ -176,7 +176,7 @@ class Transform3d:
self.device = make_device(device) self.device = make_device(device)
self.dtype = dtype self.dtype = dtype
def __len__(self): def __len__(self) -> int:
return self.get_matrix().shape[0] return self.get_matrix().shape[0]
def __getitem__( def __getitem__(
@ -462,7 +462,7 @@ class Translate(Transform3d):
z=None, z=None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: Optional[Device] = None, device: Optional[Device] = None,
): ) -> None:
""" """
Create a new Transform3d representing 3D translations. Create a new Transform3d representing 3D translations.
@ -503,7 +503,7 @@ class Scale(Transform3d):
z=None, z=None,
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: Optional[Device] = None, device: Optional[Device] = None,
): ) -> None:
""" """
A Transform3d representing a scaling operation, with different scale A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis. factors along each coordinate axis.
@ -549,7 +549,7 @@ class Rotate(Transform3d):
dtype: torch.dtype = torch.float32, dtype: torch.dtype = torch.float32,
device: Optional[Device] = None, device: Optional[Device] = None,
orthogonal_tol: float = 1e-5, orthogonal_tol: float = 1e-5,
): ) -> None:
""" """
Create a new Transform3d representing 3D rotation using a rotation Create a new Transform3d representing 3D rotation using a rotation
matrix as the input. matrix as the input.
@ -589,7 +589,7 @@ class RotateAxisAngle(Rotate):
degrees: bool = True, degrees: bool = True,
dtype: torch.dtype = torch.float64, dtype: torch.dtype = torch.float64,
device: Optional[Device] = None, device: Optional[Device] = None,
): ) -> None:
""" """
Create a new Transform3d representing 3D rotation about an axis Create a new Transform3d representing 3D rotation about an axis
by an angle. by an angle.