diff --git a/pytorch3d/renderer/lighting.py b/pytorch3d/renderer/lighting.py index fbcd1aec..99fe4548 100644 --- a/pytorch3d/renderer/lighting.py +++ b/pytorch3d/renderer/lighting.py @@ -185,7 +185,7 @@ class DirectionalLights(TensorProperties): raise ValueError(msg % repr(self.direction.shape)) def clone(self): - other = DirectionalLights(device=self.device) + other = self.__class__(device=self.device) return super().clone(other) def diffuse(self, normals, points=None) -> torch.Tensor: @@ -244,7 +244,7 @@ class PointLights(TensorProperties): raise ValueError(msg % repr(self.location.shape)) def clone(self): - other = PointLights(device=self.device) + other = self.__class__(device=self.device) return super().clone(other) def diffuse(self, normals, points) -> torch.Tensor: