Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-03 12:22:49 +08:00)
allow cameras to be None in rasterizer initialization
Summary: Fix to enable a mesh/point rasterizer to be initialized without having to specify the cameras.

Reviewed By: jcjohnson, gkioxari

Differential Revision: D21362359

fbshipit-source-id: 4f84ea18ad9f179c7b7c2289ebf9422a2f5e26de
This commit is contained in:
parent
9c5ab57156
commit
17ca6ecd81
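
For orientation, a minimal sketch of the usage this change enables, assembled from the test diff below; the sphere, camera pose and image_size values are illustrative, not prescribed by the commit:

    import torch
    from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
    from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
    from pytorch3d.utils.ico_sphere import ico_sphere

    device = torch.device("cuda:0")
    sphere_mesh = ico_sphere(1, device)

    # A rasterizer can now be created with no cameras at all ...
    rasterizer = MeshRasterizer()

    # ... as long as cameras (and settings) are supplied in the forward pass.
    R, T = look_at_view_transform(2.7, 0.0, 0.0)
    cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
    raster_settings = RasterizationSettings(image_size=256)
    fragments = rasterizer(sphere_mesh, cameras=cameras, raster_settings=raster_settings)

    # Omitting cameras both at initialization and in the forward pass raises
    # ValueError: "Cameras must be specified either at initialization or in
    # the forward pass of MeshRasterizer".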
@@ -79,6 +79,7 @@
    "metadata": {
     "colab": {},
     "colab_type": "code",
+    "collapsed": true,
     "id": "w9mH5iVprQdZ"
    },
    "outputs": [],
@@ -260,7 +261,7 @@
    "        cameras=cameras, \n",
    "        raster_settings=raster_settings\n",
    "    ),\n",
-   "    shader=HardPhongShader(device=device, lights=lights)\n",
+   "    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)\n",
    ")"
   ]
  },
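
The tutorial change mirrors the library change: since shaders no longer fall back to a default OpenGLPerspectiveCameras (see the shader diff below), the notebook cell now passes the cameras to the shader explicitly. Reconstructed from the cell source above, the renderer is built roughly as follows; the enclosing MeshRenderer/MeshRasterizer calls and the cameras, raster_settings and lights variables come from earlier tutorial cells and are assumed here:

    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings
        ),
        shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
    )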
@@ -54,7 +54,7 @@ class MeshRasterizer(nn.Module):
     Meshes.
     """

-    def __init__(self, cameras, raster_settings=None):
+    def __init__(self, cameras=None, raster_settings=None):
         """
         Args:
             cameras: A cameras object which has a `transform_points` method
@@ -88,6 +88,11 @@ class MeshRasterizer(nn.Module):
             be moved into forward.
         """
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of MeshRasterizer"
+            raise ValueError(msg)
+
         verts_world = meshes_world.verts_padded()
         verts_world_packed = meshes_world.verts_packed()
         verts_screen = cameras.transform_points(verts_world, **kwargs)
@@ -10,7 +10,6 @@ from ..blending import (
     sigmoid_alpha_blend,
     softmax_rgb_blend,
 )
-from ..cameras import OpenGLPerspectiveCameras
 from ..lighting import PointLights
 from ..materials import Materials
 from .shading import flat_shading, gouraud_shading, phong_shading
@@ -46,13 +45,16 @@ class HardPhongShader(nn.Module):
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
-        self.cameras = (
-            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
-        )
+        self.cameras = cameras

     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
-        texels = interpolate_vertex_colors(fragments, meshes)
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of HardPhongShader"
+            raise ValueError(msg)
+
+        texels = interpolate_vertex_colors(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
         colors = phong_shading(
@@ -89,14 +91,16 @@ class SoftPhongShader(nn.Module):
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
-        self.cameras = (
-            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
-        )
+        self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
-        texels = interpolate_vertex_colors(fragments, meshes)
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of SoftPhongShader"
+            raise ValueError(msg)
+        texels = interpolate_vertex_colors(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
         colors = phong_shading(
@@ -132,12 +136,14 @@ class HardGouraudShader(nn.Module):
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
-        self.cameras = (
-            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
-        )
+        self.cameras = cameras

     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of SoftPhongShader"
+            raise ValueError(msg)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
         pixel_colors = gouraud_shading(
@@ -174,13 +180,15 @@ class SoftGouraudShader(nn.Module):
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
-        self.cameras = (
-            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
-        )
+        self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of SoftPhongShader"
+            raise ValueError(msg)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
         pixel_colors = gouraud_shading(
@@ -219,14 +227,16 @@ class TexturedSoftPhongShader(nn.Module):
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
-        self.cameras = (
-            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
-        )
+        self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()

     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
-        texels = interpolate_texture_map(fragments, meshes)
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of SoftPhongShader"
+            raise ValueError(msg)
+        texels = interpolate_texture_map(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
         blend_params = kwargs.get("blend_params", self.blend_params)
@@ -262,13 +272,15 @@ class HardFlatShader(nn.Module):
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
-        self.cameras = (
-            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
-        )
+        self.cameras = cameras

     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
-        texels = interpolate_vertex_colors(fragments, meshes)
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of SoftPhongShader"
+            raise ValueError(msg)
+        texels = interpolate_vertex_colors(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
         colors = flat_shading(
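
All six shaders are changed the same way: self.cameras now stores whatever was passed at construction (possibly None), and forward resolves the camera from a per-call kwarg before shading, raising an error when neither is set. A condensed sketch of that shared resolution pattern; the standalone helper below is illustrative only, in the diff the same logic is inlined in each shader's forward:

    def resolve_cameras(shader, **kwargs):
        # A per-call `cameras=...` kwarg takes precedence over the value stored
        # at construction time; if both are None, shading cannot proceed.
        cameras = kwargs.get("cameras", shader.cameras)
        if cameras is None:
            raise ValueError(
                "Cameras must be specified either at initialization "
                "or in the forward pass of the shader"
            )
        return cameras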
@@ -48,7 +48,7 @@ class PointsRasterizer(nn.Module):
     This class implements methods for rasterizing a batch of pointclouds.
     """

-    def __init__(self, cameras, raster_settings=None):
+    def __init__(self, cameras=None, raster_settings=None):
         """
         cameras: A cameras object which has a `transform_points` method
             which returns the transformed points after applying the
@@ -80,6 +80,10 @@ class PointsRasterizer(nn.Module):
             be moved into forward.
         """
         cameras = kwargs.get("cameras", self.cameras)
+        if cameras is None:
+            msg = "Cameras must be specified either at initialization \
+                or in the forward pass of PointsRasterizer"
+            raise ValueError(msg)

         pts_world = point_clouds.points_padded()
         pts_world_packed = point_clouds.points_packed()
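
The same pattern applies to PointsRasterizer. A minimal usage sketch following the new TestPointRasterizer test below; the point cloud and camera setup mirror the test, and the CUDA device is an assumption:

    import torch
    from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
    from pytorch3d.renderer.points.rasterizer import (
        PointsRasterizationSettings,
        PointsRasterizer,
    )
    from pytorch3d.structures import Pointclouds
    from pytorch3d.utils.ico_sphere import ico_sphere

    device = torch.device("cuda:0")
    pointclouds = Pointclouds(points=ico_sphere(1, device).verts_padded())

    # Construct without cameras and supply them per call.
    rasterizer = PointsRasterizer()

    R, T = look_at_view_transform(2.7, 0.0, 0.0)
    cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
    raster_settings = PointsRasterizationSettings(
        image_size=256, radius=5e-2, points_per_pixel=1
    )
    fragments = rasterizer(pointclouds, cameras=cameras, raster_settings=raster_settings)

    # As with MeshRasterizer, calling with no cameras anywhere raises a ValueError.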
@@ -9,6 +9,11 @@ import torch
 from PIL import Image
 from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
 from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
+from pytorch3d.renderer.points.rasterizer import (
+    PointsRasterizationSettings,
+    PointsRasterizer,
+)
+from pytorch3d.structures import Pointclouds
 from pytorch3d.utils.ico_sphere import ico_sphere


@@ -99,3 +104,101 @@ class TestMeshRasterizer(unittest.TestCase):
                 DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png"
             )
         self.assertTrue(torch.allclose(image, image_ref))
+
+        #################################
+        # 4. Test init without cameras.
+        ##################################
+
+        # Create a new empty rasterizer:
+        rasterizer = MeshRasterizer()
+
+        # Check that omitting the cameras in both initialization
+        # and the forward pass throws an error:
+        with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
+            rasterizer(sphere_mesh)
+
+        # Now pass in the cameras as a kwarg
+        fragments = rasterizer(
+            sphere_mesh, cameras=cameras, raster_settings=raster_settings
+        )
+        image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
+        # Convert pix_to_face to a binary mask
+        image[image >= 0] = 1.0
+        image[image < 0] = 0.0
+
+        if DEBUG:
+            Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
+                DATA_DIR / "DEBUG_test_rasterized_sphere.png"
+            )
+
+        self.assertTrue(torch.allclose(image, image_ref))
+
+
+class TestPointRasterizer(unittest.TestCase):
+    def test_simple_sphere(self):
+        device = torch.device("cuda:0")
+
+        # Load reference image
+        ref_filename = "test_simple_pointcloud_sphere.png"
+        image_ref_filename = DATA_DIR / ref_filename
+
+        # Rescale image_ref to the 0 - 1 range and convert to a binary mask.
+        image_ref = convert_image_to_binary_mask(image_ref_filename).to(torch.int32)
+
+        sphere_mesh = ico_sphere(1, device)
+        verts_padded = sphere_mesh.verts_padded()
+        verts_padded[..., 1] += 0.2
+        verts_padded[..., 0] += 0.2
+        pointclouds = Pointclouds(points=verts_padded)
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        raster_settings = PointsRasterizationSettings(
+            image_size=256, radius=5e-2, points_per_pixel=1
+        )
+
+        #################################
+        # 1. Test init without cameras.
+        ##################################
+
+        # Initialize without passing in the cameras
+        rasterizer = PointsRasterizer()
+
+        # Check that omitting the cameras in both initialization
+        # and the forward pass throws an error:
+        with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
+            rasterizer(pointclouds)
+
+        ##########################################
+        # 2. Test rasterizing a single pointcloud
+        ##########################################
+
+        fragments = rasterizer(
+            pointclouds, cameras=cameras, raster_settings=raster_settings
+        )
+
+        # Convert idx to a binary mask
+        image = fragments.idx[0, ..., 0].squeeze().cpu()
+        image[image >= 0] = 1.0
+        image[image < 0] = 0.0
+
+        if DEBUG:
+            Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
+                DATA_DIR / "DEBUG_test_rasterized_sphere_points.png"
+            )
+
+        self.assertTrue(torch.allclose(image, image_ref[..., 0]))
+
+        ########################################
+        # 3. Test with a batch of pointclouds
+        ########################################
+
+        batch_size = 10
+        pointclouds = pointclouds.extend(batch_size)
+        fragments = rasterizer(
+            pointclouds, cameras=cameras, raster_settings=raster_settings
+        )
+        for i in range(batch_size):
+            image = fragments.idx[i, ..., 0].squeeze().cpu()
+            image[image >= 0] = 1.0
+            image[image < 0] = 0.0
+            self.assertTrue(torch.allclose(image, image_ref[..., 0]))