mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add type hints to MeshRenderer(WithFragments)
Reviewed By: bottler Differential Revision: D36148049 fbshipit-source-id: 87ca3ea8d5b5a315418cc597b36fd0a1dffb1e00
This commit is contained in:
parent
ec9580a1d4
commit
2c64635daa
@ -4,10 +4,14 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...structures.meshes import Meshes
|
||||
from .rasterizer import MeshRasterizer
|
||||
|
||||
|
||||
# A renderer class should be initialized with a
|
||||
# function for rasterization and a function for shading.
|
||||
@ -32,7 +36,7 @@ class MeshRenderer(nn.Module):
|
||||
function.
|
||||
"""
|
||||
|
||||
def __init__(self, rasterizer, shader) -> None:
|
||||
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
|
||||
super().__init__()
|
||||
self.rasterizer = rasterizer
|
||||
self.shader = shader
|
||||
@ -43,7 +47,7 @@ class MeshRenderer(nn.Module):
|
||||
self.shader.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
||||
def forward(self, meshes_world: Meshes, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Render a batch of images from a batch of meshes by rasterizing and then
|
||||
shading.
|
||||
@ -76,7 +80,7 @@ class MeshRendererWithFragments(nn.Module):
|
||||
depth = fragments.zbuf
|
||||
"""
|
||||
|
||||
def __init__(self, rasterizer, shader) -> None:
|
||||
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
|
||||
super().__init__()
|
||||
self.rasterizer = rasterizer
|
||||
self.shader = shader
|
||||
@ -85,8 +89,11 @@ class MeshRendererWithFragments(nn.Module):
|
||||
# Rasterizer and shader have submodules which are not of type nn.Module
|
||||
self.rasterizer.to(device)
|
||||
self.shader.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, meshes_world, **kwargs):
|
||||
def forward(
|
||||
self, meshes_world: Meshes, **kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Render a batch of images from a batch of meshes by rasterizing and then
|
||||
shading.
|
||||
|
Loading…
x
Reference in New Issue
Block a user