From 6a365d203fff6b10c53e7483b33eac48d067055b Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 11 May 2020 12:55:30 -0700 Subject: [PATCH] Pointclouds, Meshes and Textures self-references Summary: Use `self.__class__` when creating new instances, to slightly accommodate inheritance. Reviewed By: nikhilaravi Differential Revision: D21504476 fbshipit-source-id: b4600d15462fc1985da95a4cf761c7d794cfb0bb --- pytorch3d/structures/meshes.py | 8 ++++---- pytorch3d/structures/pointclouds.py | 8 ++++---- pytorch3d/structures/textures.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index e7340443..a4431ce3 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -436,9 +436,9 @@ class Meshes(object): textures = None if self.textures is None else self.textures[index] if torch.is_tensor(verts) and torch.is_tensor(faces): - return Meshes(verts=[verts], faces=[faces], textures=textures) + return self.__class__(verts=[verts], faces=[faces], textures=textures) elif isinstance(verts, list) and isinstance(faces, list): - return Meshes(verts=verts, faces=faces, textures=textures) + return self.__class__(verts=verts, faces=faces, textures=textures) else: raise ValueError("(verts, faces) not defined correctly") @@ -1127,7 +1127,7 @@ class Meshes(object): faces_list = self.faces_list() new_verts_list = [v.clone() for v in verts_list] new_faces_list = [f.clone() for f in faces_list] - other = Meshes(verts=new_verts_list, faces=new_faces_list) + other = self.__class__(verts=new_verts_list, faces=new_faces_list) for k in self._INTERNAL_TENSORS: v = getattr(self, k) if torch.is_tensor(v): @@ -1370,7 +1370,7 @@ class Meshes(object): if self.textures is not None: tex = self.textures.extend(N) - return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex) + return self.__class__(verts=new_verts_list, faces=new_faces_list, textures=tex) def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True): diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 082507ca..687de83c 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -341,7 +341,7 @@ class Pointclouds(object): else: raise IndexError(index) - return Pointclouds(points=points, normals=normals, features=features) + return self.__class__(points=points, normals=normals, features=features) def isempty(self) -> bool: """ @@ -647,7 +647,7 @@ class Pointclouds(object): new_normals = self.normals_padded().clone() if features_padded is not None: new_features = self.features_padded().clone() - other = Pointclouds( + other = self.__class__( points=new_points, normals=new_normals, features=new_features ) for k in self._INTERNAL_TENSORS: @@ -920,7 +920,7 @@ class Pointclouds(object): new_features_list = [] for features in self.features_list(): new_features_list.extend(features.clone() for _ in range(N)) - return Pointclouds( + return self.__class__( points=new_points_list, normals=new_normals_list, features=new_features_list ) @@ -959,7 +959,7 @@ class Pointclouds(object): if new_features_padded is not None: check_shapes(new_features_padded, [self._N, self._P, self._C]) - new = Pointclouds( + new = self.__class__( points=new_points_padded, normals=new_normals_padded, features=new_features_padded, diff --git a/pytorch3d/structures/textures.py b/pytorch3d/structures/textures.py index fa32ae38..803167ba 100644 --- a/pytorch3d/structures/textures.py +++ b/pytorch3d/structures/textures.py @@ -129,7 +129,7 @@ class Textures(object): self._num_verts_per_mesh = None def clone(self): - other = Textures() + other = self.__class__() for k in dir(self): v = getattr(self, k) if torch.is_tensor(v): @@ -144,7 +144,7 @@ class Textures(object): return self def __getitem__(self, index): - other = Textures() + other = self.__class__() for key in dir(self): value = getattr(self, key) if torch.is_tensor(value): @@ -237,12 +237,12 @@ class Textures(object): new_verts_uvs = _extend_tensor(self._verts_uvs_padded, N) new_faces_uvs = _extend_tensor(self._faces_uvs_padded, N) new_maps = _extend_tensor(self._maps_padded, N) - return Textures( + return self.__class__( verts_uvs=new_verts_uvs, faces_uvs=new_faces_uvs, maps=new_maps ) elif self._verts_rgb_padded is not None: new_verts_rgb = _extend_tensor(self._verts_rgb_padded, N) - return Textures(verts_rgb=new_verts_rgb) + return self.__class__(verts_rgb=new_verts_rgb) else: msg = "Either vertex colors or texture maps are required." raise ValueError(msg)