Add type hints to MeshRenderer(WithFragments)

Reviewed By: bottler

Differential Revision: D36148049

fbshipit-source-id: 87ca3ea8d5b5a315418cc597b36fd0a1dffb1e00
This commit is contained in:
Krzysztof Chalupka 2022-05-06 14:48:26 -07:00 committed by Facebook GitHub Bot
parent ec9580a1d4
commit 2c64635daa

View File

@ -4,10 +4,14 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...structures.meshes import Meshes
from .rasterizer import MeshRasterizer
# A renderer class should be initialized with a # A renderer class should be initialized with a
# function for rasterization and a function for shading. # function for rasterization and a function for shading.
@ -32,7 +36,7 @@ class MeshRenderer(nn.Module):
function. function.
""" """
def __init__(self, rasterizer, shader) -> None: def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
super().__init__() super().__init__()
self.rasterizer = rasterizer self.rasterizer = rasterizer
self.shader = shader self.shader = shader
@ -43,7 +47,7 @@ class MeshRenderer(nn.Module):
self.shader.to(device) self.shader.to(device)
return self return self
def forward(self, meshes_world, **kwargs) -> torch.Tensor: def forward(self, meshes_world: Meshes, **kwargs) -> torch.Tensor:
""" """
Render a batch of images from a batch of meshes by rasterizing and then Render a batch of images from a batch of meshes by rasterizing and then
shading. shading.
@ -76,7 +80,7 @@ class MeshRendererWithFragments(nn.Module):
depth = fragments.zbuf depth = fragments.zbuf
""" """
def __init__(self, rasterizer, shader) -> None: def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
super().__init__() super().__init__()
self.rasterizer = rasterizer self.rasterizer = rasterizer
self.shader = shader self.shader = shader
@ -85,8 +89,11 @@ class MeshRendererWithFragments(nn.Module):
# Rasterizer and shader have submodules which are not of type nn.Module # Rasterizer and shader have submodules which are not of type nn.Module
self.rasterizer.to(device) self.rasterizer.to(device)
self.shader.to(device) self.shader.to(device)
return self
def forward(self, meshes_world, **kwargs): def forward(
self, meshes_world: Meshes, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Render a batch of images from a batch of meshes by rasterizing and then Render a batch of images from a batch of meshes by rasterizing and then
shading. shading.