From 2c64635daa2aa728f35ed4abe41c6942ae8c0d8b Mon Sep 17 00:00:00 2001 From: Krzysztof Chalupka Date: Fri, 6 May 2022 14:48:26 -0700 Subject: [PATCH] Add type hints to MeshRenderer(WithFragments) Reviewed By: bottler Differential Revision: D36148049 fbshipit-source-id: 87ca3ea8d5b5a315418cc597b36fd0a1dffb1e00 --- pytorch3d/renderer/mesh/renderer.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py index e374a0cf..f44c1bd3 100644 --- a/pytorch3d/renderer/mesh/renderer.py +++ b/pytorch3d/renderer/mesh/renderer.py @@ -4,10 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Tuple import torch import torch.nn as nn +from ...structures.meshes import Meshes +from .rasterizer import MeshRasterizer + # A renderer class should be initialized with a # function for rasterization and a function for shading. @@ -32,7 +36,7 @@ class MeshRenderer(nn.Module): function. """ - def __init__(self, rasterizer, shader) -> None: + def __init__(self, rasterizer: MeshRasterizer, shader) -> None: super().__init__() self.rasterizer = rasterizer self.shader = shader @@ -43,7 +47,7 @@ class MeshRenderer(nn.Module): self.shader.to(device) return self - def forward(self, meshes_world, **kwargs) -> torch.Tensor: + def forward(self, meshes_world: Meshes, **kwargs) -> torch.Tensor: """ Render a batch of images from a batch of meshes by rasterizing and then shading. @@ -76,7 +80,7 @@ class MeshRendererWithFragments(nn.Module): depth = fragments.zbuf """ - def __init__(self, rasterizer, shader) -> None: + def __init__(self, rasterizer: MeshRasterizer, shader) -> None: super().__init__() self.rasterizer = rasterizer self.shader = shader @@ -85,8 +89,11 @@ class MeshRendererWithFragments(nn.Module): # Rasterizer and shader have submodules which are not of type nn.Module self.rasterizer.to(device) self.shader.to(device) + return self - def forward(self, meshes_world, **kwargs): + def forward( + self, meshes_world: Meshes, **kwargs + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Render a batch of images from a batch of meshes by rasterizing and then shading.