mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Add type hints to MeshRenderer(WithFragments)
Reviewed By: bottler Differential Revision: D36148049 fbshipit-source-id: 87ca3ea8d5b5a315418cc597b36fd0a1dffb1e00
This commit is contained in:
parent
ec9580a1d4
commit
2c64635daa
@ -4,10 +4,14 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from ...structures.meshes import Meshes
|
||||||
|
from .rasterizer import MeshRasterizer
|
||||||
|
|
||||||
|
|
||||||
# A renderer class should be initialized with a
|
# A renderer class should be initialized with a
|
||||||
# function for rasterization and a function for shading.
|
# function for rasterization and a function for shading.
|
||||||
@ -32,7 +36,7 @@ class MeshRenderer(nn.Module):
|
|||||||
function.
|
function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, rasterizer, shader) -> None:
|
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.shader = shader
|
self.shader = shader
|
||||||
@ -43,7 +47,7 @@ class MeshRenderer(nn.Module):
|
|||||||
self.shader.to(device)
|
self.shader.to(device)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
def forward(self, meshes_world: Meshes, **kwargs) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Render a batch of images from a batch of meshes by rasterizing and then
|
Render a batch of images from a batch of meshes by rasterizing and then
|
||||||
shading.
|
shading.
|
||||||
@ -76,7 +80,7 @@ class MeshRendererWithFragments(nn.Module):
|
|||||||
depth = fragments.zbuf
|
depth = fragments.zbuf
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, rasterizer, shader) -> None:
|
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.shader = shader
|
self.shader = shader
|
||||||
@ -85,8 +89,11 @@ class MeshRendererWithFragments(nn.Module):
|
|||||||
# Rasterizer and shader have submodules which are not of type nn.Module
|
# Rasterizer and shader have submodules which are not of type nn.Module
|
||||||
self.rasterizer.to(device)
|
self.rasterizer.to(device)
|
||||||
self.shader.to(device)
|
self.shader.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, meshes_world, **kwargs):
|
def forward(
|
||||||
|
self, meshes_world: Meshes, **kwargs
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Render a batch of images from a batch of meshes by rasterizing and then
|
Render a batch of images from a batch of meshes by rasterizing and then
|
||||||
shading.
|
shading.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user