Add MeshRasterizerOpenGL

Summary:
Adding MeshRasterizerOpenGL, a faster alternative to MeshRasterizer. The new rasterizer follows the ideas from "Differentiable Surface Rendering via non-Differentiable Sampling".

The new rasterizer 20x faster on a 2M face mesh (try pose optimization on Nefertiti from https://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/!). The larger the mesh, the larger the speedup.

There are two main disadvantages:
* The new rasterizer works with an OpenGL backend, so requires pycuda.gl and pyopengl installed (though we avoided writing any C++ code, everything is in Python!)
* The new rasterizer is non-differentiable. However, you can still differentiate the rendering function if you use if with the new SplatterPhongShader which we recently added to PyTorch3D (see the original paper cited above).

Reviewed By: patricklabatut, jcjohnson

Differential Revision: D37698816

fbshipit-source-id: 54d120639d3cb001f096237807e54aced0acda25
This commit is contained in:
Krzysztof Chalupka 2022-07-22 15:52:50 -07:00 committed by Facebook GitHub Bot
parent 36edf2b302
commit cb49550486
66 changed files with 1556 additions and 337 deletions

View File

@ -66,7 +66,7 @@ from .mesh import (
)
try:
from .opengl import EGLContext, global_device_context_store
from .opengl import EGLContext, global_device_context_store, MeshRasterizerOpenGL
except (ImportError, ModuleNotFoundError):
pass # opengl or pycuda.gl not available, or pytorch3_opengl not in TARGETS.

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from .clip import (
clip_faces,
ClipFrustum,

View File

@ -11,6 +11,8 @@ import numpy as np
import torch
from pytorch3d import _C
from ..utils import parse_image_size
from .clip import (
clip_faces,
ClipFrustum,
@ -149,20 +151,8 @@ def rasterize_meshes(
# If the ratio of H:W is large this might cause issues as the smaller
# dimension will have fewer bins.
# TODO: consider a better way of setting the bin size.
if isinstance(image_size, (tuple, list)):
if len(image_size) != 2:
raise ValueError("Image size can only be a tuple/list of (H, W)")
if not all(i > 0 for i in image_size):
raise ValueError(
"Image sizes must be greater than 0; got %d, %d" % image_size
)
if not all(type(i) == int for i in image_size):
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
max_image_size = max(*image_size)
im_size = image_size
else:
im_size = (image_size, image_size)
max_image_size = image_size
im_size = parse_image_size(image_size)
max_image_size = max(*im_size)
clipped_faces_neighbor_idx = None

View File

@ -57,14 +57,14 @@ class Fragments:
pix_to_face: torch.Tensor
zbuf: torch.Tensor
bary_coords: torch.Tensor
dists: torch.Tensor
dists: Optional[torch.Tensor]
def detach(self) -> "Fragments":
return Fragments(
pix_to_face=self.pix_to_face,
zbuf=self.zbuf.detach(),
bary_coords=self.bary_coords.detach(),
dists=self.dists.detach(),
dists=self.dists.detach() if self.dists is not None else self.dists,
)
@ -85,6 +85,8 @@ class RasterizationSettings:
bin_size=0 uses naive rasterization; setting bin_size=None attempts
to set it heuristically based on the shape of the input. This should
not affect the output, but can affect the speed of the forward pass.
max_faces_opengl: Max number of faces in any mesh we will rasterize. Used only by
MeshRasterizerOpenGL to pre-allocate OpenGL memory.
max_faces_per_bin: Only applicable when using coarse-to-fine
rasterization (bin_size != 0); this is the maximum number of faces
allowed within each bin. This should not affect the output values,
@ -122,6 +124,7 @@ class RasterizationSettings:
blur_radius: float = 0.0
faces_per_pixel: int = 1
bin_size: Optional[int] = None
max_faces_opengl: int = 10_000_000
max_faces_per_bin: Optional[int] = None
perspective_correct: Optional[bool] = None
clip_barycentric_coords: Optional[bool] = None
@ -237,6 +240,10 @@ class MeshRasterizer(nn.Module):
znear = znear.min().item()
z_clip = None if not perspective_correct or znear is None else znear / 2
# By default, turn on clip_barycentric_coords if blur_radius > 0.
# When blur_radius > 0, a face can be matched to a pixel that is outside the
# face, resulting in negative barycentric coordinates.
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
meshes_proj,
image_size=raster_settings.image_size,
@ -250,6 +257,10 @@ class MeshRasterizer(nn.Module):
z_clip_value=z_clip,
cull_to_frustum=raster_settings.cull_to_frustum,
)
return Fragments(
pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists
pix_to_face=pix_to_face,
zbuf=zbuf,
bary_coords=bary_coords,
dists=dists,
)

View File

@ -349,6 +349,9 @@ class SplatterPhongShader(ShaderBase):
N, H, W, K, _ = colors.shape
self.splatter_blender = SplatterBlender((N, H, W, K), colors.device)
blend_params = kwargs.get("blend_params", self.blend_params)
self.check_blend_params(blend_params)
images = self.splatter_blender(
colors,
pixel_coords_cameras,
@ -359,6 +362,14 @@ class SplatterPhongShader(ShaderBase):
return images
def check_blend_params(self, blend_params):
if blend_params.sigma != 0.5:
warnings.warn(
f"SplatterPhongShader received sigma={blend_params.sigma}. sigma is "
"defined in pixel units, and any value other than 0.5 is highly "
"unexpected. Only use other values if you know what you are doing. "
)
class HardDepthShader(ShaderBase):
"""

View File

@ -4,7 +4,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, NamedTuple, Tuple
import torch

View File

@ -32,5 +32,6 @@ def _can_import_egl_and_pycuda():
if _can_import_egl_and_pycuda():
from .opengl_utils import EGLContext, global_device_context_store
from .rasterizer_opengl import MeshRasterizerOpenGL
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@ -224,11 +224,13 @@ class EGLContext:
"""
self.lock.acquire()
egl.eglMakeCurrent(self.dpy, self.surface, self.surface, self.context)
yield
egl.eglMakeCurrent(
self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
)
self.lock.release()
try:
yield
finally:
egl.eglMakeCurrent(
self.dpy, egl.EGL_NO_SURFACE, egl.EGL_NO_SURFACE, egl.EGL_NO_CONTEXT
)
self.lock.release()
def get_context_info(self) -> Dict[str, Any]:
"""
@ -418,5 +420,29 @@ def _init_cuda_context(device_id: int = 0):
return cuda_context
def _torch_to_opengl(torch_tensor, cuda_context, cuda_buffer):
# CUDA access to the OpenGL buffer is only allowed within a map-unmap block.
cuda_context.push()
mapping_obj = cuda_buffer.map()
# data_ptr points to the OpenGL shader storage buffer memory.
data_ptr, sz = mapping_obj.device_ptr_and_size()
# Copy the torch tensor to the OpenGL buffer directly on device.
cuda_copy = cuda.Memcpy2D()
cuda_copy.set_src_device(torch_tensor.data_ptr())
cuda_copy.set_dst_device(data_ptr)
cuda_copy.width_in_bytes = cuda_copy.src_pitch = cuda_copy.dst_ptch = (
torch_tensor.shape[1] * 4
)
cuda_copy.height = torch_tensor.shape[0]
cuda_copy(False)
# Unmap and pop the cuda context to make sure OpenGL won't interfere with
# PyTorch ops down the line.
mapping_obj.unmap()
cuda_context.pop()
# Initialize a global _DeviceContextStore. Almost always we will only need a single one.
global_device_context_store = _DeviceContextStore()

View File

@ -0,0 +1,710 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# NOTE: This module (as well as rasterizer_opengl) will not be imported into pytorch3d
# if you do not have pycuda.gl and pyopengl installed. In addition, please make sure
# your Python application *does not* import OpenGL before importing PyTorch3D, unless
# you are using the EGL backend.
import warnings
from typing import Optional, Tuple, Union
import numpy as np
import OpenGL.GL as gl
import pycuda.gl
import torch
import torch.nn as nn
from pytorch3d.structures.meshes import Meshes
from ..cameras import FoVOrthographicCameras, FoVPerspectiveCameras
from ..mesh.rasterizer import Fragments, RasterizationSettings
from ..utils import parse_image_size
from .opengl_utils import _torch_to_opengl, global_device_context_store
# Shader strings, used below to compile an OpenGL program.
vertex_shader = """
// The vertex shader does nothing.
#version 430
void main() { }
"""
geometry_shader = """
#version 430
layout (points) in;
layout (triangle_strip, max_vertices = 3) out;
out layout (location = 0) vec2 bary_coords;
out layout (location = 1) float depth;
out layout (location = 2) float p2f;
layout(binding=0) buffer triangular_mesh { float mesh_buffer[]; };
uniform mat4 perspective_projection;
vec3 get_vertex_position(int vertex_index) {
int offset = gl_PrimitiveIDIn * 9 + vertex_index * 3;
return vec3(
mesh_buffer[offset],
mesh_buffer[offset + 1],
mesh_buffer[offset + 2]
);
}
void main() {
vec3 positions[3] = {
get_vertex_position(0),
get_vertex_position(1),
get_vertex_position(2)
};
vec4 projected_vertices[3] = {
perspective_projection * vec4(positions[0], 1.0),
perspective_projection * vec4(positions[1], 1.0),
perspective_projection * vec4(positions[2], 1.0)
};
for (int i = 0; i < 3; ++i) {
gl_Position = projected_vertices[i];
bary_coords = vec2(i==0 ? 1.0 : 0.0, i==1 ? 1.0 : 0.0);
// At the moment, we output depth as the distance from the image plane in
// view coordinates -- NOT distance along the camera ray.
depth = positions[i][2];
p2f = gl_PrimitiveIDIn;
EmitVertex();
}
EndPrimitive();
}
"""
fragment_shader = """
#version 430
in layout(location = 0) vec2 bary_coords;
in layout(location = 1) float depth;
in layout(location = 2) float p2f;
out vec4 bary_depth_p2f;
void main() {
bary_depth_p2f = vec4(bary_coords, depth, round(p2f));
}
"""
def _parse_and_verify_image_size(
image_size: Union[Tuple[int, int], int],
) -> Tuple[int, int]:
"""
Parse image_size as a tuple of ints. Throw ValueError if the size is incompatible
with the maximum renderable size as set in global_device_context_store.
"""
height, width = parse_image_size(image_size)
max_h = global_device_context_store.max_egl_height
max_w = global_device_context_store.max_egl_width
if height > max_h or width > max_w:
raise ValueError(
f"Max rasterization size is height={max_h}, width={max_w}. "
f"Cannot raster an image of size {height}, {width}. You can change max "
"allowed rasterization size by modifying the MAX_EGL_HEIGHT and "
"MAX_EGL_WIDTH environment variables."
)
return height, width
class MeshRasterizerOpenGL(nn.Module):
"""
EXPERIMENTAL, USE WITH CAUTION
This class implements methods for rasterizing a batch of heterogeneous
Meshes using OpenGL. This rasterizer, as opposed to MeshRasterizer, is
*not differentiable* and needs to be used with shading methods such as
SplatterPhongShader, which do not require differentiable rasterizerization.
It is, however, faster: on a 2M-faced mesh, about 20x so.
Fragments output by MeshRasterizerOpenGL and MeshRasterizer should have near
identical pix_to_face, bary_coords and zbuf. However, MeshRasterizerOpenGL does not
return Fragments.dists which is only relevant to SoftPhongShader which doesn't work
with MeshRasterizerOpenGL (because it is not differentiable).
"""
def __init__(
self,
cameras: Optional[Union[FoVOrthographicCameras, FoVPerspectiveCameras]] = None,
raster_settings=None,
) -> None:
"""
Args:
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the
world-to-view and view-to-ndc transformations. Currently, only FoV
cameras are supported.
raster_settings: the parameters for rasterization. This should be a
named tuple.
"""
super().__init__()
if raster_settings is None:
raster_settings = RasterizationSettings()
self.raster_settings = raster_settings
_check_raster_settings(self.raster_settings)
self.cameras = cameras
self.image_size = _parse_and_verify_image_size(self.raster_settings.image_size)
self.opengl_machinery = _OpenGLMachinery(
max_faces=self.raster_settings.max_faces_opengl,
)
def forward(self, meshes_world: Meshes, **kwargs) -> Fragments:
"""
Args:
meshes_world: a Meshes object representing a batch of meshes with
coordinates in world space. The batch must live on a GPU.
Returns:
Fragments: Rasterization outputs as a named tuple. These are different than
Fragments returned by MeshRasterizer in two ways. First, we return no
`dist` which is only relevant to SoftPhongShader which doesn't work
with MeshRasterizerOpenGL (because it is not differentiable). Second,
the zbuf uses the opengl zbuf convention, where the z-vals are between 0
(at projection plane) and 1 (at clipping distance), and are a non-linear
function of the depth values of the camera ray intersections. In
contrast, MeshRasterizer's zbuf values are simply the distance of each
ray intersection from the camera.
Throws:
ValueError if meshes_world lives on the CPU.
"""
if meshes_world.device == torch.device("cpu"):
raise ValueError("MeshRasterizerOpenGL works only on CUDA devices.")
raster_settings = kwargs.get("raster_settings", self.raster_settings)
_check_raster_settings(raster_settings)
image_size = (
_parse_and_verify_image_size(raster_settings.image_size) or self.image_size
)
# OpenGL needs vertices in NDC coordinates with un-flipped xy directions.
cameras_unpacked = kwargs.get("cameras", self.cameras)
_check_cameras(cameras_unpacked)
meshes_gl_ndc = _convert_meshes_to_gl_ndc(
meshes_world, image_size, cameras_unpacked, **kwargs
)
# Perspective projection will happen within the OpenGL rasterizer.
projection_matrix = cameras_unpacked.get_projection_transform(**kwargs)._matrix
# Run OpenGL rasterization machinery.
pix_to_face, bary_coords, zbuf = self.opengl_machinery(
meshes_gl_ndc, projection_matrix, image_size
)
# Return the Fragments and detach, because gradients don't go through OpenGL.
return Fragments(
pix_to_face=pix_to_face,
zbuf=zbuf,
bary_coords=bary_coords,
dists=None,
).detach()
def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module
if self.cameras is not None:
self.cameras = self.cameras.to(device)
# Create a new OpenGLMachinery, as its member variables can be tied to a GPU.
self.opengl_machinery = _OpenGLMachinery(
max_faces=self.raster_settings.max_faces_opengl,
)
class _OpenGLMachinery:
"""
A class holding OpenGL machinery used by MeshRasterizerOpenGL.
"""
def __init__(
self,
max_faces: int = 10_000_000,
) -> None:
self.max_faces = max_faces
self.program = None
# These will be created on an appropriate GPU each time we render a new mesh on
# that GPU for the first time.
self.egl_context = None
self.cuda_context = None
self.perspective_projection_uniform = None
self.mesh_buffer_object = None
self.vao = None
self.fbo = None
self.cuda_buffer = None
def __call__(
self,
meshes_gl_ndc: Meshes,
projection_matrix: torch.Tensor,
image_size: Tuple[int, int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rasterize a batch of meshes, using a given batch of projection matrices and
image size.
Args:
meshes_gl_ndc: A Meshes object, with vertices in the OpenGL NDC convention.
projection_matrix: A 3x3 camera projection matrix, or a tensor of projection
matrices equal in length to the number of meshes in meshes_gl_ndc.
image_size: Image size to rasterize. Must be smaller than the max height and
width stored in global_device_context_store.
Returns:
pix_to_faces: A BHW1 tensor of ints, filled with -1 where no face projects
to a given pixel.
bary_coords: A BHW3 float tensor, filled with -1 where no face projects to
a given pixel.
zbuf: A BHW1 float tensor, filled with 1 where no face projects to a given
pixel. NOTE: this zbuf uses the opengl zbuf convention, where the z-vals
are between 0 (at projection plane) and 1 (at clipping distance), and
are a non-linear function of the depth values of the camera ray inter-
sections.
"""
self.initialize_device_data(meshes_gl_ndc.device)
with self.egl_context.active_and_locked():
# Perspective projection happens in OpenGL. Move the matrix over if there's only
# a single camera shared by all the meshes.
if projection_matrix.shape[0] == 1:
self._projection_matrix_to_opengl(projection_matrix)
pix_to_faces = []
bary_coords = []
zbufs = []
# pyre-ignore Incompatible parameter type [6]
for mesh_id, mesh in enumerate(meshes_gl_ndc):
pix_to_face, bary_coord, zbuf = self._rasterize_mesh(
mesh,
image_size,
projection_matrix=projection_matrix[mesh_id]
if projection_matrix.shape[0] > 1
else None,
)
pix_to_faces.append(pix_to_face)
bary_coords.append(bary_coord)
zbufs.append(zbuf)
return (
torch.cat(pix_to_faces, dim=0),
torch.cat(bary_coords, dim=0),
torch.cat(zbufs, dim=0),
)
def initialize_device_data(self, device) -> None:
"""
Initialize data specific to a GPU device: the EGL and CUDA contexts, the OpenGL
program, as well as various buffer and array objects used to communicate with
OpenGL.
Args:
device: A torch.device.
"""
self.egl_context = global_device_context_store.get_egl_context(device)
self.cuda_context = global_device_context_store.get_cuda_context(device)
# self.program represents the OpenGL program we use for rasterization.
if global_device_context_store.get_context_data(device) is None:
with self.egl_context.active_and_locked():
self.program = self._compile_and_link_gl_program()
self._set_up_gl_program_properties(self.program)
# Create objects used to transfer data into and out of the program.
(
self.perspective_projection_uniform,
self.mesh_buffer_object,
self.vao,
self.fbo,
) = self._prepare_persistent_opengl_objects(
self.program,
self.max_faces,
)
# Register the input buffer with pycuda, to transfer data directly into it.
self.cuda_context.push()
self.cuda_buffer = pycuda.gl.RegisteredBuffer(
int(self.mesh_buffer_object),
pycuda.gl.graphics_map_flags.WRITE_DISCARD,
)
self.cuda_context.pop()
global_device_context_store.set_context_data(
device,
(
self.program,
self.perspective_projection_uniform,
self.mesh_buffer_object,
self.vao,
self.fbo,
self.cuda_buffer,
),
)
(
self.program,
self.perspective_projection_uniform,
self.mesh_buffer_object,
self.vao,
self.fbo,
self.cuda_buffer,
) = global_device_context_store.get_context_data(device)
def release(self) -> None:
"""
Release CUDA and OpenGL resources.
"""
# Finish all current operations.
torch.cuda.synchronize()
self.cuda_context.synchronize()
# Free pycuda resources.
self.cuda_context.push()
self.cuda_buffer.unregister()
self.cuda_context.pop()
# Free GL resources.
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo)
gl.glDeleteFramebuffers(1, [self.fbo])
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
del self.fbo
gl.glBindBufferBase(gl.GL_SHADER_STORAGE_BUFFER, 0, self.mesh_buffer_object)
gl.glDeleteBuffers(1, [self.mesh_buffer_object])
gl.glBindBufferBase(gl.GL_SHADER_STORAGE_BUFFER, 0, 0)
del self.mesh_buffer_object
gl.glDeleteProgram(self.program)
self.egl_context.release()
def _projection_matrix_to_opengl(self, projection_matrix: torch.Tensor) -> None:
"""
Transfer a torch projection matrix to OpenGL.
Args:
projection matrix: A 3x3 float tensor.
"""
gl.glUseProgram(self.program)
gl.glUniformMatrix4fv(
self.perspective_projection_uniform,
1,
gl.GL_FALSE,
projection_matrix.detach().flatten().cpu().numpy().astype(np.float32),
)
gl.glUseProgram(0)
def _rasterize_mesh(
self,
mesh: Meshes,
image_size: Tuple[int, int],
projection_matrix: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rasterize a single mesh using OpenGL.
Args:
mesh: A Meshes object, containing a single mesh only.
projection_matrix: A 3x3 camera projection matrix, or a tensor of projection
matrices equal in length to the number of meshes in meshes_gl_ndc.
image_size: Image size to rasterize. Must be smaller than the max height and
width stored in global_device_context_store.
Returns:
pix_to_faces: A 1HW1 tensor of ints, filled with -1 where no face projects
to a given pixel.
bary_coords: A 1HW3 float tensor, filled with -1 where no face projects to
a given pixel.
zbuf: A 1HW1 float tensor, filled with 1 where no face projects to a given
pixel. NOTE: this zbuf uses the opengl zbuf convention, where the z-vals
are between 0 (at projection plane) and 1 (at clipping distance), and
are a non-linear function of the depth values of the camera ray inter-
sections.
"""
height, width = image_size
# Extract face_verts and move them to OpenGL as well. We use pycuda to
# directly move the vertices on the GPU, to avoid a costly torch/GPU -> CPU
# -> openGL/GPU trip.
verts_packed = mesh.verts_packed().detach()
faces_packed = mesh.faces_packed().detach()
face_verts = verts_packed[faces_packed].reshape(-1, 9)
_torch_to_opengl(face_verts, self.cuda_context, self.cuda_buffer)
if projection_matrix is not None:
self._projection_matrix_to_opengl(projection_matrix)
# Start OpenGL operations.
gl.glUseProgram(self.program)
# Render an image of size (width, height).
gl.glViewport(0, 0, width, height)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, self.fbo)
# Clear the output framebuffer. The "background" value for both pix_to_face
# as well as bary_coords is -1 (background = pixels which the rasterizer
# projected no triangle to).
gl.glClearColor(-1.0, -1.0, -1.0, -1.0)
gl.glClearDepth(1.0)
# pyre-ignore Unsupported operand [58]
gl.glClear(gl.GL_COLOR_BUFFER_BIT | gl.GL_DEPTH_BUFFER_BIT)
# Run the actual rendering. The face_verts were transported to the OpenGL
# program into a shader storage buffer which is used directly in the geometry
# shader. Here, we only pass the number of these vertices to the vertex shader
# (which doesn't do anything and passes directly to the geometry shader).
gl.glBindVertexArray(self.vao)
gl.glDrawArrays(gl.GL_POINTS, 0, len(face_verts))
gl.glBindVertexArray(0)
# Read out the result. We ignore the depth buffer. The RGBA color buffer stores
# barycentrics in the RGB component and pix_to_face in the A component.
bary_depth_p2f_gl = gl.glReadPixels(
0,
0,
width,
height,
gl.GL_RGBA,
gl.GL_FLOAT,
)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
# Create torch tensors containing the results.
bary_depth_p2f = (
torch.frombuffer(bary_depth_p2f_gl, dtype=torch.float)
.reshape(1, height, width, 1, -1)
.to(verts_packed.device)
)
# Read out barycentrics. GL only outputs the first two, so we need to compute
# the third one and make sure we still leave no-intersection pixels with -1.
barycentric_coords = torch.cat(
[
bary_depth_p2f[..., :2],
1.0 - bary_depth_p2f[..., 0:1] - bary_depth_p2f[..., 1:2],
],
dim=-1,
)
barycentric_coords = torch.where(
barycentric_coords == 3, -1, barycentric_coords
)
depth = bary_depth_p2f[..., 2:3].squeeze(-1)
pix_to_face = bary_depth_p2f[..., -1].long()
return pix_to_face, barycentric_coords, depth
@staticmethod
def _compile_and_link_gl_program():
"""
Compile the vertex, geometry, and fragment shaders and link them into an OpenGL
program. The shader sources are strongly inspired by https://github.com/tensorflow/
graphics/blob/master/tensorflow_graphics/rendering/opengl/rasterization_backend.py.
Returns:
An OpenGL program for mesh rasterization.
"""
program = gl.glCreateProgram()
shader_objects = []
for shader_string, shader_type in zip(
[vertex_shader, geometry_shader, fragment_shader],
[gl.GL_VERTEX_SHADER, gl.GL_GEOMETRY_SHADER, gl.GL_FRAGMENT_SHADER],
):
shader_objects.append(gl.glCreateShader(shader_type))
gl.glShaderSource(shader_objects[-1], shader_string)
gl.glCompileShader(shader_objects[-1])
status = gl.glGetShaderiv(shader_objects[-1], gl.GL_COMPILE_STATUS)
if status == gl.GL_FALSE:
gl.glDeleteShader(shader_objects[-1])
gl.glDeleteProgram(program)
error_msg = gl.glGetShaderInfoLog(shader_objects[-1]).decode("utf-8")
raise RuntimeError(f"Compilation failure:\n {error_msg}")
gl.glAttachShader(program, shader_objects[-1])
gl.glDeleteShader(shader_objects[-1])
gl.glLinkProgram(program)
status = gl.glGetProgramiv(program, gl.GL_LINK_STATUS)
if status == gl.GL_FALSE:
gl.glDeleteProgram(program)
error_msg = gl.glGetProgramInfoLog(program)
raise RuntimeError(f"Link failure:\n {error_msg}")
return program
@staticmethod
def _set_up_gl_program_properties(program) -> None:
"""
Set basic OpenGL program properties: disable blending, enable depth testing,
and disable face culling.
"""
gl.glUseProgram(program)
gl.glDisable(gl.GL_BLEND)
gl.glEnable(gl.GL_DEPTH_TEST)
gl.glDisable(gl.GL_CULL_FACE)
gl.glUseProgram(0)
@staticmethod
def _prepare_persistent_opengl_objects(program, max_faces: int):
"""
Prepare OpenGL objects that we want to persist between rasterizations.
Args:
program: The OpenGL program the resources will be tied to.
max_faces: Max number of faces of any mesh we will rasterize.
Returns:
perspective_projection_uniform: An OpenGL object pointing to a location of
the perspective projection matrix in OpenGL memory.
mesh_buffer_object: An OpenGL object pointing to the location of the mesh
buffer object in OpenGL memory.
vao: The OpenGL input array object.
fbo: The OpenGL output framebuffer.
"""
gl.glUseProgram(program)
# Get location of the "uniform" (that is, an internal OpenGL variable available
# to the shaders) that we'll load the projection matrices to.
perspective_projection_uniform = gl.glGetUniformLocation(
program, "perspective_projection"
)
# Mesh buffer object -- our main input point. We'll copy the mesh here
# from pytorch/cuda. The buffer needs enough space to store the three vertices
# of each face, that is its size in bytes is
# max_faces * 3 (vertices) * 3 (coordinates) * 4 (bytes)
mesh_buffer_object = gl.glGenBuffers(1)
gl.glBindBufferBase(gl.GL_SHADER_STORAGE_BUFFER, 0, mesh_buffer_object)
gl.glBufferData(
gl.GL_SHADER_STORAGE_BUFFER,
max_faces * 9 * 4,
np.zeros((max_faces, 9), dtype=np.float32),
gl.GL_DYNAMIC_COPY,
)
# Input vertex array object. We will only use it implicitly for indexing the
# vertices, but the actual input data is passed in the shader storage buffer.
vao = gl.glGenVertexArrays(1)
# Create the framebuffer object (fbo) where we'll store output data.
MAX_EGL_WIDTH = global_device_context_store.max_egl_width
MAX_EGL_HEIGHT = global_device_context_store.max_egl_height
color_buffer = gl.glGenRenderbuffers(1)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, color_buffer)
gl.glRenderbufferStorage(
gl.GL_RENDERBUFFER, gl.GL_RGBA32F, MAX_EGL_WIDTH, MAX_EGL_HEIGHT
)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, 0)
depth_buffer = gl.glGenRenderbuffers(1)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, depth_buffer)
gl.glRenderbufferStorage(
gl.GL_RENDERBUFFER, gl.GL_DEPTH_COMPONENT, MAX_EGL_WIDTH, MAX_EGL_HEIGHT
)
gl.glBindRenderbuffer(gl.GL_RENDERBUFFER, 0)
fbo = gl.glGenFramebuffers(1)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo)
gl.glFramebufferRenderbuffer(
gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_RENDERBUFFER, color_buffer
)
gl.glFramebufferRenderbuffer(
gl.GL_FRAMEBUFFER, gl.GL_DEPTH_ATTACHMENT, gl.GL_RENDERBUFFER, depth_buffer
)
gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0)
gl.glUseProgram(0)
return perspective_projection_uniform, mesh_buffer_object, vao, fbo
def _check_cameras(cameras) -> None:
# Check that the cameras are non-None and compatible with MeshRasterizerOpenGL.
if cameras is None:
msg = "Cameras must be specified either at initialization \
or in the forward pass of MeshRasterizer"
raise ValueError(msg)
if type(cameras).__name__ in {"PerspectiveCameras", "OrthographicCameras"}:
raise ValueError(
"MeshRasterizerOpenGL only works with FoVPerspectiveCameras and "
"FoVOrthographicCameras, which are OpenGL compatible."
)
def _check_raster_settings(raster_settings) -> None:
# Check that the rasterizer's settings are compatible with MeshRasterizerOpenGL.
if raster_settings.faces_per_pixel > 1:
warnings.warn(
"MeshRasterizerOpenGL currently works only with one face per pixel."
)
if raster_settings.cull_backfaces:
warnings.warn(
"MeshRasterizerOpenGL cannot cull backfaces yet, rasterizing without culling."
)
if raster_settings.cull_to_frustum:
warnings.warn(
"MeshRasterizerOpenGL cannot cull to frustum yet, rasterizing without culling."
)
if raster_settings.z_clip_value is not None:
raise NotImplementedError("MeshRasterizerOpenGL cannot do z-clipping yet.")
if raster_settings.perspective_correct is False:
raise ValueError(
"MeshRasterizerOpenGL always uses perspective-correct interpolation."
)
def _convert_meshes_to_gl_ndc(
meshes_world: Meshes, image_size: Tuple[int, int], camera, **kwargs
) -> Meshes:
"""
Convert a batch of world-coordinate meshes to GL NDC coordinates.
Args:
meshes_world: Meshes in the world coordinate system.
image_size: Image height and width, used to modify mesh coords for rendering in
non-rectangular images. OpenGL will expand anything within the [-1, 1] NDC
range to fit the width and height of the screen, so we will squeeze the NDCs
appropriately if rendering a rectangular image.
camera: FoV cameras.
kwargs['R'], kwargs['T']: If present, used to define the world-view transform.
"""
height, width = image_size
verts_ndc = (
camera.get_world_to_view_transform(**kwargs)
.compose(camera.get_ndc_camera_transform(**kwargs))
.transform_points(meshes_world.verts_padded(), eps=None)
)
verts_ndc[..., 0] = -verts_ndc[..., 0]
verts_ndc[..., 1] = -verts_ndc[..., 1]
# In case of a non-square viewport, transform the vertices. OpenGL will expand
# the anything within the [-1, 1] NDC range to fit the width and height of the
# screen. So to work with PyTorch3D cameras, we need to squeeze the NDCs
# appropriately.
dtype, device = verts_ndc.dtype, verts_ndc.device
if height > width:
verts_ndc = verts_ndc * torch.tensor(
[1, width / height, 1], dtype=dtype, device=device
)
elif width > height:
verts_ndc = verts_ndc * torch.tensor(
[height / width, 1, 1], dtype=dtype, device=device
)
meshes_gl_ndc = meshes_world.update_padded(new_verts_padded=verts_ndc)
return meshes_gl_ndc

View File

@ -11,6 +11,8 @@ import torch
from pytorch3d import _C
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_non_square_ndc
from ..utils import parse_image_size
# Maximum number of faces per bins for
# coarse-to-fine rasterization
@ -102,20 +104,8 @@ def rasterize_points(
# If the ratio of H:W is large this might cause issues as the smaller
# dimension will have fewer bins.
# TODO: consider a better way of setting the bin size.
if isinstance(image_size, (tuple, list)):
if len(image_size) != 2:
raise ValueError("Image size can only be a tuple/list of (H, W)")
if not all(i > 0 for i in image_size):
raise ValueError(
"Image sizes must be greater than 0; got %d, %d" % image_size
)
if not all(type(i) == int for i in image_size):
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
max_image_size = max(*image_size)
im_size = image_size
else:
im_size = (image_size, image_size)
max_image_size = image_size
im_size = parse_image_size(image_size)
max_image_size = max(*im_size)
if bin_size is None:
if not points_packed.is_cuda:

View File

@ -8,7 +8,7 @@
import copy
import inspect
import warnings
from typing import Any, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union
import numpy as np
import torch
@ -432,3 +432,27 @@ def ndc_to_grid_sample_coords(
else:
xy_grid_sample[..., 0] *= aspect
return xy_grid_sample
def parse_image_size(
image_size: Union[List[int], Tuple[int, int], int]
) -> Tuple[int, int]:
"""
Args:
image_size: A single int (for square images) or a tuple/list of two ints.
Returns:
A tuple of two ints.
Throws:
ValueError if got more than two ints, any negative numbers or non-ints.
"""
if not isinstance(image_size, (tuple, list)):
return (image_size, image_size)
if len(image_size) != 2:
raise ValueError("Image size can only be a tuple/list of (H, W)")
if not all(i > 0 for i in image_size):
raise ValueError("Image sizes must be greater than 0; got %d, %d" % image_size)
if not all(type(i) == int for i in image_size):
raise ValueError("Image sizes must be integers; got %f, %f" % image_size)
return tuple(image_size)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 97 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 120 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 35 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

View File

Before

Width:  |  Height:  |  Size: 25 KiB

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 72 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View File

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View File

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

View File

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 12 KiB

View File

Before

Width:  |  Height:  |  Size: 11 KiB

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

View File

Before

Width:  |  Height:  |  Size: 2.2 KiB

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 568 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 568 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 29 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 10 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.3 KiB

View File

Before

Width:  |  Height:  |  Size: 758 B

After

Width:  |  Height:  |  Size: 758 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 758 B

View File

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

View File

Before

Width:  |  Height:  |  Size: 31 KiB

After

Width:  |  Height:  |  Size: 31 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

View File

Before

Width:  |  Height:  |  Size: 30 KiB

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 32 KiB

View File

@ -26,8 +26,15 @@ class TestBuild(unittest.TestCase):
sys.modules.pop(module, None)
root_dir = get_pytorch3d_dir() / "pytorch3d"
# Exclude opengl-related files, as Implicitron is decoupled from opengl
# components which will not work without adding a dep on pytorch3d_opengl.
for module_file in root_dir.glob("**/*.py"):
if module_file.stem in ("__init__", "plotly_vis", "opengl_utils"):
if module_file.stem in (
"__init__",
"plotly_vis",
"opengl_utils",
"rasterizer_opengl",
):
continue
relative_module = str(module_file.relative_to(root_dir))[:-3]
module = "pytorch3d." + relative_module.replace("/", ".")

View File

@ -11,15 +11,18 @@ import numpy as np
import torch
from PIL import Image
from pytorch3d.io import load_obj
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.lighting import PointLights
from pytorch3d.renderer.materials import Materials
from pytorch3d.renderer.mesh import (
from pytorch3d.renderer import (
BlendParams,
FoVPerspectiveCameras,
look_at_view_transform,
Materials,
MeshRasterizer,
MeshRasterizerOpenGL,
MeshRenderer,
PointLights,
RasterizationSettings,
SoftPhongShader,
SplatterPhongShader,
TexturesUV,
)
from pytorch3d.renderer.mesh.rasterize_meshes import (
@ -454,6 +457,12 @@ class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase):
)
def test_render_cow(self):
self._render_cow(MeshRasterizer)
def test_render_cow_opengl(self):
self._render_cow(MeshRasterizerOpenGL)
def _render_cow(self, rasterizer_type):
"""
Test a larger textured mesh is rendered correctly in a non square image.
"""
@ -473,38 +482,55 @@ class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase):
mesh = Meshes(verts=[verts], faces=[faces.verts_idx], textures=textures)
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 180)
R, T = look_at_view_transform(1.2, 0, 90)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=(512, 1024), blur_radius=0.0, faces_per_pixel=1
image_size=(500, 800), blur_radius=0.0, faces_per_pixel=1
)
# Init shader settings
materials = Materials(device=device)
lights = PointLights(device=device)
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
# Init renderer
renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=SoftPhongShader(
rasterizer = rasterizer_type(cameras=cameras, raster_settings=raster_settings)
if rasterizer_type == MeshRasterizer:
blend_params = BlendParams(
sigma=1e-1,
gamma=1e-4,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
shader = SoftPhongShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
),
)
)
else:
blend_params = BlendParams(
sigma=0.5,
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
)
shader = SplatterPhongShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
# Load reference image
image_ref = load_rgb_image("test_cow_image_rectangle.png", DATA_DIR)
image_ref = load_rgb_image(
f"test_cow_image_rectangle_{rasterizer_type.__name__}.png", DATA_DIR
)
for bin_size in [0, None]:
if bin_size == 0 and rasterizer_type == MeshRasterizerOpenGL:
continue
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(mesh)
@ -512,7 +538,8 @@ class TestRasterizeRectangleImagesMeshes(TestCaseMixin, unittest.TestCase):
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_cow_image_rectangle.png"
DATA_DIR
/ f"DEBUG_cow_image_rectangle_{rasterizer_type.__name__}.png"
)
# NOTE some pixels can be flaky

View File

@ -10,16 +10,29 @@ import unittest
import numpy as np
import torch
from PIL import Image
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.renderer.points.rasterizer import (
from pytorch3d.renderer import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
look_at_view_transform,
MeshRasterizer,
MeshRasterizerOpenGL,
OrthographicCameras,
PerspectiveCameras,
PointsRasterizationSettings,
PointsRasterizer,
RasterizationSettings,
)
from pytorch3d.renderer.opengl.rasterizer_opengl import (
_check_cameras,
_check_raster_settings,
_convert_meshes_to_gl_ndc,
_parse_and_verify_image_size,
)
from pytorch3d.structures import Pointclouds
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere
from .common_testing import get_tests_dir
from .common_testing import get_tests_dir, TestCaseMixin
DATA_DIR = get_tests_dir() / "data"
@ -36,8 +49,14 @@ def convert_image_to_binary_mask(filename):
class TestMeshRasterizer(unittest.TestCase):
def test_simple_sphere(self):
self._simple_sphere(MeshRasterizer)
def test_simple_sphere_opengl(self):
self._simple_sphere(MeshRasterizerOpenGL)
def _simple_sphere(self, rasterizer_type):
device = torch.device("cuda:0")
ref_filename = "test_rasterized_sphere.png"
ref_filename = f"test_rasterized_sphere_{rasterizer_type.__name__}.png"
image_ref_filename = DATA_DIR / ref_filename
# Rescale image_ref to the 0 - 1 range and convert to a binary mask.
@ -54,7 +73,7 @@ class TestMeshRasterizer(unittest.TestCase):
)
# Init rasterizer
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
rasterizer = rasterizer_type(cameras=cameras, raster_settings=raster_settings)
####################################
# 1. Test rasterizing a single mesh
@ -68,7 +87,8 @@ class TestMeshRasterizer(unittest.TestCase):
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere.png"
DATA_DIR
/ f"DEBUG_test_rasterized_sphere_{rasterizer_type.__name__}.png"
)
self.assertTrue(torch.allclose(image, image_ref))
@ -90,20 +110,21 @@ class TestMeshRasterizer(unittest.TestCase):
# 3. Test that passing kwargs to rasterizer works.
####################################################
# Change the view transform to zoom in.
R, T = look_at_view_transform(2.0, 0, 0, device=device)
# Change the view transform to zoom out.
R, T = look_at_view_transform(20.0, 0, 0, device=device)
fragments = rasterizer(sphere_mesh, R=R, T=T)
image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
image[image >= 0] = 1.0
image[image < 0] = 0.0
ref_filename = "test_rasterized_sphere_zoom.png"
ref_filename = f"test_rasterized_sphere_zoom_{rasterizer_type.__name__}.png"
image_ref_filename = DATA_DIR / ref_filename
image_ref = convert_image_to_binary_mask(image_ref_filename)
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png"
DATA_DIR
/ f"DEBUG_test_rasterized_sphere_zoom_{rasterizer_type.__name__}.png"
)
self.assertTrue(torch.allclose(image, image_ref))
@ -112,7 +133,7 @@ class TestMeshRasterizer(unittest.TestCase):
##################################
# Create a new empty rasterizer:
rasterizer = MeshRasterizer()
rasterizer = rasterizer_type(raster_settings=raster_settings)
# Check that omitting the cameras in both initialization
# and the forward pass throws an error:
@ -120,9 +141,7 @@ class TestMeshRasterizer(unittest.TestCase):
rasterizer(sphere_mesh)
# Now pass in the cameras as a kwarg
fragments = rasterizer(
sphere_mesh, cameras=cameras, raster_settings=raster_settings
)
fragments = rasterizer(sphere_mesh, cameras=cameras)
image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
# Convert pix_to_face to a binary mask
image[image >= 0] = 1.0
@ -130,7 +149,8 @@ class TestMeshRasterizer(unittest.TestCase):
if DEBUG:
Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_test_rasterized_sphere.png"
DATA_DIR
/ f"DEBUG_test_rasterized_sphere_{rasterizer_type.__name__}.png"
)
self.assertTrue(torch.allclose(image, image_ref))
@ -141,6 +161,187 @@ class TestMeshRasterizer(unittest.TestCase):
rasterizer = MeshRasterizer()
rasterizer.to(device)
rasterizer = MeshRasterizerOpenGL()
rasterizer.to(device)
def test_compare_rasterizers(self):
device = torch.device("cuda:0")
# Init rasterizer settings
R, T = look_at_view_transform(2.7, 0, 0)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512,
blur_radius=0.0,
faces_per_pixel=1,
bin_size=0,
perspective_correct=True,
)
from pytorch3d.io import load_obj
from pytorch3d.renderer import TexturesAtlas
from .common_testing import get_pytorch3d_dir
TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
obj_filename = TUTORIAL_DATA_DIR / "cow_mesh/cow.obj"
# Load mesh and texture as a per face texture atlas.
verts, faces, aux = load_obj(
obj_filename,
device=device,
load_textures=True,
create_texture_atlas=True,
texture_atlas_size=8,
texture_wrap=None,
)
atlas = aux.texture_atlas
mesh = Meshes(
verts=[verts],
faces=[faces.verts_idx],
textures=TexturesAtlas(atlas=[atlas]),
)
# Rasterize using both rasterizers and compare results.
rasterizer = MeshRasterizerOpenGL(
cameras=cameras, raster_settings=raster_settings
)
fragments_opengl = rasterizer(mesh)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
fragments = rasterizer(mesh)
# Ensure that 99.9% of bary_coords is at most 0.001 different.
self.assertLess(
torch.quantile(
(fragments.bary_coords - fragments_opengl.bary_coords).abs(), 0.999
),
0.001,
)
# Ensure that 99.9% of zbuf vals is at most 0.001 different.
self.assertLess(
torch.quantile((fragments.zbuf - fragments_opengl.zbuf).abs(), 0.999), 0.001
)
# Ensure that 99.99% of pix_to_face is identical.
self.assertEqual(
torch.quantile(
(fragments.pix_to_face != fragments_opengl.pix_to_face).float(), 0.9999
),
0,
)
class TestMeshRasterizerOpenGLUtils(TestCaseMixin, unittest.TestCase):
def setUp(self):
verts = torch.tensor(
[[-1, 1, 0], [1, 1, 0], [1, -1, 0]], dtype=torch.float32
).cuda()
faces = torch.tensor([[0, 1, 2]]).cuda()
self.meshes_world = Meshes(verts=[verts], faces=[faces])
# Test various utils specific to the OpenGL rasterizer. Full "integration tests"
# live in test_render_meshes and test_render_multigpu.
def test_check_cameras(self):
_check_cameras(FoVPerspectiveCameras())
_check_cameras(FoVPerspectiveCameras())
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
_check_cameras(None)
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
_check_cameras(PerspectiveCameras())
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
_check_cameras(OrthographicCameras())
MeshRasterizerOpenGL(FoVPerspectiveCameras().cuda())(self.meshes_world)
MeshRasterizerOpenGL(FoVOrthographicCameras().cuda())(self.meshes_world)
MeshRasterizerOpenGL()(
self.meshes_world, cameras=FoVPerspectiveCameras().cuda()
)
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
MeshRasterizerOpenGL(PerspectiveCameras().cuda())(self.meshes_world)
with self.assertRaisesRegex(ValueError, "MeshRasterizerOpenGL only works with"):
MeshRasterizerOpenGL(OrthographicCameras().cuda())(self.meshes_world)
with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
MeshRasterizerOpenGL()(self.meshes_world)
def test_check_raster_settings(self):
raster_settings = RasterizationSettings()
raster_settings.faces_per_pixel = 100
with self.assertWarnsRegex(UserWarning, ".* one face per pixel"):
_check_raster_settings(raster_settings)
with self.assertWarnsRegex(UserWarning, ".* one face per pixel"):
MeshRasterizerOpenGL(raster_settings=raster_settings)(
self.meshes_world, cameras=FoVPerspectiveCameras().cuda()
)
def test_convert_meshes_to_gl_ndc_square_img(self):
R, T = look_at_view_transform(1, 90, 180)
cameras = FoVOrthographicCameras(R=R, T=T).cuda()
meshes_gl_ndc = _convert_meshes_to_gl_ndc(
self.meshes_world, (100, 100), cameras
)
# After look_at_view_transform rotating 180 deg around z-axis, we recover
# the original coordinates. After additionally elevating the view by 90
# deg, we "zero out" the y-coordinate. Finally, we negate the x and y axes
# to adhere to OpenGL conventions (which go against the PyTorch3D convention).
self.assertClose(
meshes_gl_ndc.verts_list()[0],
torch.tensor(
[[-1, 0, 0], [1, 0, 0], [1, 0, 2]], dtype=torch.float32
).cuda(),
atol=1e-5,
)
def test_parse_and_verify_image_size(self):
img_size = _parse_and_verify_image_size(512)
self.assertEqual(img_size, (512, 512))
img_size = _parse_and_verify_image_size((2047, 10))
self.assertEqual(img_size, (2047, 10))
img_size = _parse_and_verify_image_size((10, 2047))
self.assertEqual(img_size, (10, 2047))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
_parse_and_verify_image_size((2049, 512))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
_parse_and_verify_image_size((512, 2049))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
_parse_and_verify_image_size((2049, 2049))
rasterizer = MeshRasterizerOpenGL(FoVPerspectiveCameras().cuda())
raster_settings = RasterizationSettings()
raster_settings.image_size = 512
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 512, 512, 1]))
raster_settings.image_size = (2047, 10)
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 2047, 10, 1]))
raster_settings.image_size = (10, 2047)
fragments = rasterizer(self.meshes_world, raster_settings=raster_settings)
self.assertEqual(fragments.pix_to_face.shape, torch.Size([1, 10, 2047, 1]))
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (2049, 512)
rasterizer(self.meshes_world, raster_settings=raster_settings)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (512, 2049)
rasterizer(self.meshes_world, raster_settings=raster_settings)
with self.assertRaisesRegex(ValueError, "Max rasterization size is"):
raster_settings.image_size = (2049, 2049)
rasterizer(self.meshes_world, raster_settings=raster_settings)
class TestPointRasterizer(unittest.TestCase):
def test_simple_sphere(self):

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ from pytorch3d.renderer import (
HardGouraudShader,
Materials,
MeshRasterizer,
MeshRasterizerOpenGL,
MeshRenderer,
PointLights,
PointsRasterizationSettings,
@ -21,18 +22,19 @@ from pytorch3d.renderer import (
PointsRenderer,
RasterizationSettings,
SoftPhongShader,
SplatterPhongShader,
TexturesVertex,
)
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere
from .common_testing import get_random_cuda_device, TestCaseMixin
from .common_testing import TestCaseMixin
# Set the number of GPUS you want to test with
NUM_GPUS = 3
GPU_LIST = list({get_random_cuda_device() for _ in range(NUM_GPUS)})
NUM_GPUS = 2
GPU_LIST = [f"cuda:{idx}" for idx in range(NUM_GPUS)]
print("GPUs: %s" % ", ".join(GPU_LIST))
@ -56,12 +58,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
self.assertEqual(renderer.shader.materials.device, device)
self.assertEqual(renderer.shader.materials.ambient_color.device, device)
def test_mesh_renderer_to(self):
def _mesh_renderer_to(self, rasterizer_class, shader_class):
"""
Test moving all the tensors in the mesh renderer to a new device.
"""
device1 = torch.device("cpu")
device1 = torch.device("cuda:0")
R, T = look_at_view_transform(1500, 0.0, 0.0)
@ -71,12 +73,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]
raster_settings = RasterizationSettings(
image_size=256, blur_radius=0.0, faces_per_pixel=1
image_size=128, blur_radius=0.0, faces_per_pixel=1
)
cameras = FoVPerspectiveCameras(
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
rasterizer = rasterizer_class(cameras=cameras, raster_settings=raster_settings)
blend_params = BlendParams(
1e-4,
@ -84,7 +86,7 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
background_color=torch.zeros(3, dtype=torch.float32, device=device1),
)
shader = SoftPhongShader(
shader = shader_class(
lights=lights,
cameras=cameras,
materials=materials,
@ -107,26 +109,32 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
# Move renderer and mesh to another device and re render
# This also tests that background_color is correctly moved to
# the new device
device2 = torch.device("cuda:0")
device2 = torch.device("cuda:1")
renderer = renderer.to(device2)
mesh = mesh.to(device2)
self._check_mesh_renderer_props_on_device(renderer, device2)
output_images = renderer(mesh)
self.assertEqual(output_images.device, device2)
def test_render_meshes(self):
def test_mesh_renderer_to(self):
self._mesh_renderer_to(MeshRasterizer, SoftPhongShader)
def test_mesh_renderer_opengl_to(self):
self._mesh_renderer_to(MeshRasterizerOpenGL, SplatterPhongShader)
def _render_meshes(self, rasterizer_class, shader_class):
test = self
class Model(nn.Module):
def __init__(self):
def __init__(self, device):
super(Model, self).__init__()
mesh = ico_sphere(3)
mesh = ico_sphere(3).to(device)
self.register_buffer("faces", mesh.faces_padded())
self.renderer = self.init_render()
self.renderer = self.init_render(device)
def init_render(self):
def init_render(self, device):
cameras = FoVPerspectiveCameras()
cameras = FoVPerspectiveCameras().to(device)
raster_settings = RasterizationSettings(
image_size=128, blur_radius=0.0, faces_per_pixel=1
)
@ -135,12 +143,12 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
diffuse_color=((0, 0.0, 0),),
specular_color=((0.0, 0, 0),),
location=((0.0, 0.0, 1e5),),
)
).to(device)
renderer = MeshRenderer(
rasterizer=MeshRasterizer(
rasterizer=rasterizer_class(
cameras=cameras, raster_settings=raster_settings
),
shader=HardGouraudShader(cameras=cameras, lights=lights),
shader=shader_class(cameras=cameras, lights=lights),
)
return renderer
@ -155,20 +163,25 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
img_render = self.renderer(mesh)
return img_render[:, :, :, :3]
# DataParallel requires every input tensor be provided
# on the first device in its device_ids list.
verts = ico_sphere(3).verts_padded()
# Make sure we use all GPUs in GPU_LIST by making the batch size 4 x GPU count.
verts = ico_sphere(3).verts_padded().expand(len(GPU_LIST) * 4, 642, 3)
texs = verts.new_ones(verts.shape)
model = Model()
model.to(GPU_LIST[0])
model = Model(device=GPU_LIST[0])
model = nn.DataParallel(model, device_ids=GPU_LIST)
# Test a few iterations
for _ in range(100):
model(verts, texs)
def test_render_meshes(self):
self._render_meshes(MeshRasterizer, HardGouraudShader)
class TestRenderPointssMultiGPU(TestCaseMixin, unittest.TestCase):
# @unittest.skip("Multi-GPU OpenGL training is currently not supported.")
def test_render_meshes_opengl(self):
self._render_meshes(MeshRasterizerOpenGL, SplatterPhongShader)
class TestRenderPointsMultiGPU(TestCaseMixin, unittest.TestCase):
def _check_points_renderer_props_on_device(self, renderer, device):
"""
Helper function to check that all the properties have