Add type hints to MeshRenderer(WithFragments)

Reviewed By: bottler

Differential Revision: D36148049

fbshipit-source-id: 87ca3ea8d5b5a315418cc597b36fd0a1dffb1e00
This commit is contained in:
Krzysztof Chalupka 2022-05-06 14:48:26 -07:00 committed by Facebook GitHub Bot
parent ec9580a1d4
commit 2c64635daa

View File

@ -4,10 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
import torch.nn as nn
from ...structures.meshes import Meshes
from .rasterizer import MeshRasterizer
# A renderer class should be initialized with a
# function for rasterization and a function for shading.
@ -32,7 +36,7 @@ class MeshRenderer(nn.Module):
function.
"""
def __init__(self, rasterizer, shader) -> None:
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
@ -43,7 +47,7 @@ class MeshRenderer(nn.Module):
self.shader.to(device)
return self
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
def forward(self, meshes_world: Meshes, **kwargs) -> torch.Tensor:
"""
Render a batch of images from a batch of meshes by rasterizing and then
shading.
@ -76,7 +80,7 @@ class MeshRendererWithFragments(nn.Module):
depth = fragments.zbuf
"""
def __init__(self, rasterizer, shader) -> None:
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
@ -85,8 +89,11 @@ class MeshRendererWithFragments(nn.Module):
# Rasterizer and shader have submodules which are not of type nn.Module
self.rasterizer.to(device)
self.shader.to(device)
return self
def forward(self, meshes_world, **kwargs):
def forward(
self, meshes_world: Meshes, **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Render a batch of images from a batch of meshes by rasterizing and then
shading.