allow cameras to be None in rasterizer initialization

Summary: Fix to enable a mesh/point rasterizer to be initialized without having to specify the camera.

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21362359

fbshipit-source-id: 4f84ea18ad9f179c7b7c2289ebf9422a2f5e26de
This commit is contained in:
Nikhila Ravi
2020-05-05 22:29:38 -07:00
committed by Facebook GitHub Bot
parent 9c5ab57156
commit 17ca6ecd81
5 changed files with 151 additions and 26 deletions

View File

@@ -54,7 +54,7 @@ class MeshRasterizer(nn.Module):
Meshes.
"""
def __init__(self, cameras, raster_settings=None):
def __init__(self, cameras=None, raster_settings=None):
"""
Args:
cameras: A cameras object which has a `transform_points` method
@@ -88,6 +88,11 @@ class MeshRasterizer(nn.Module):
be moved into forward.
"""
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of MeshRasterizer"
raise ValueError(msg)
verts_world = meshes_world.verts_padded()
verts_world_packed = meshes_world.verts_packed()
verts_screen = cameras.transform_points(verts_world, **kwargs)

View File

@@ -10,7 +10,6 @@ from ..blending import (
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from ..cameras import OpenGLPerspectiveCameras
from ..lighting import PointLights
from ..materials import Materials
from .shading import flat_shading, gouraud_shading, phong_shading
@@ -46,13 +45,16 @@ class HardPhongShader(nn.Module):
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.cameras = cameras
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of HardPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors = phong_shading(
@@ -89,14 +91,16 @@ class SoftPhongShader(nn.Module):
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors = phong_shading(
@@ -132,12 +136,14 @@ class HardGouraudShader(nn.Module):
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.cameras = cameras
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
pixel_colors = gouraud_shading(
@@ -174,13 +180,15 @@ class SoftGouraudShader(nn.Module):
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
pixel_colors = gouraud_shading(
@@ -219,14 +227,16 @@ class TexturedSoftPhongShader(nn.Module):
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_texture_map(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_texture_map(fragments, meshes)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
blend_params = kwargs.get("blend_params", self.blend_params)
@@ -262,13 +272,15 @@ class HardFlatShader(nn.Module):
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
)
self.cameras = cameras
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of SoftPhongShader"
raise ValueError(msg)
texels = interpolate_vertex_colors(fragments, meshes)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors = flat_shading(

View File

@@ -48,7 +48,7 @@ class PointsRasterizer(nn.Module):
This class implements methods for rasterizing a batch of pointclouds.
"""
def __init__(self, cameras, raster_settings=None):
def __init__(self, cameras=None, raster_settings=None):
"""
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the
@@ -80,6 +80,10 @@ class PointsRasterizer(nn.Module):
be moved into forward.
"""
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of PointsRasterizer"
raise ValueError(msg)
pts_world = point_clouds.points_padded()
pts_world_packed = point_clouds.points_packed()