diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py index 16872130..f35ab33c 100644 --- a/pytorch3d/datasets/__init__.py +++ b/pytorch3d/datasets/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .r2n2 import R2N2, BlenderCamera +from .r2n2 import R2N2, BlenderCamera, collate_batched_R2N2, render_cubified_voxels from .shapenet import ShapeNetCore from .utils import collate_batched_meshes diff --git a/pytorch3d/datasets/r2n2/__init__.py b/pytorch3d/datasets/r2n2/__init__.py index f18dc45d..b40be2c1 100644 --- a/pytorch3d/datasets/r2n2/__init__.py +++ b/pytorch3d/datasets/r2n2/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .r2n2 import R2N2, BlenderCamera +from .r2n2 import R2N2 +from .utils import BlenderCamera, collate_batched_R2N2, render_cubified_voxels __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py index ecf3fe62..3abe4927 100644 --- a/pytorch3d/datasets/r2n2/r2n2.py +++ b/pytorch3d/datasets/r2n2/r2n2.py @@ -10,20 +10,32 @@ import numpy as np import torch from PIL import Image from pytorch3d.datasets.shapenet_base import ShapeNetBase -from pytorch3d.datasets.utils import compute_extrinsic_matrix from pytorch3d.io import load_obj from pytorch3d.renderer import HardPhongShader -from pytorch3d.renderer.cameras import CamerasBase -from pytorch3d.transforms import Transform3d from tabulate import tabulate +from .utils import ( + BlenderCamera, + align_bbox, + compute_extrinsic_matrix, + read_binvox_coords, + voxelize, +) + SYNSET_DICT_DIR = Path(__file__).resolve().parent - -# Default values of rotation, translation and intrinsic matrices for BlenderCamera. -r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3) -t = np.expand_dims(np.zeros(3), axis=0) # (1, 3) -k = np.expand_dims(np.eye(4), axis=0) # (1, 4, 4) +MAX_CAMERA_DISTANCE = 1.75 # Constant from R2N2. +VOXEL_SIZE = 128 +# Intrinsic matrix extracted from Blender. Taken from meshrcnn codebase: +# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py +BLENDER_INTRINSIC = torch.tensor( + [ + [2.1875, 0.0, 0.0, 0.0], + [0.0, 2.1875, 0.0, 0.0], + [0.0, 0.0, -1.002002, -0.2002002], + [0.0, 0.0, -1.0, 0.0], + ] +) class R2N2(ShapeNetBase): @@ -42,6 +54,7 @@ class R2N2(ShapeNetBase): r2n2_dir, splits_file, return_all_views: bool = True, + return_voxels: bool = False, ): """ Store each object's synset id and models id the given directories. @@ -54,6 +67,8 @@ class R2N2(ShapeNetBase): return_all_views (bool): Indicator of whether or not to load all the views in the split. If set to False, one of the views in the split will be randomly selected and loaded. + return_voxels(bool): Indicator of whether or not to return voxels as a tensor + of shape (D, D, D) where D is the number of voxels along each dimension. """ super().__init__() self.shapenet_dir = shapenet_dir @@ -83,6 +98,16 @@ class R2N2(ShapeNetBase): ) % (r2n2_dir) warnings.warn(msg) + self.return_voxels = return_voxels + # Check if the folder containing voxel coordinates is included in r2n2_dir. + if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")): + self.return_voxels = False + msg = ( + "ShapeNetVox32 not found in %s. Voxel coordinates will " + "be skipped when returning models." + ) % (r2n2_dir) + warnings.warn(msg) + synset_set = set() # Store lists of views of each model in a list. 
self.views_per_model_list = [] @@ -173,6 +198,8 @@ class R2N2(ShapeNetBase): - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned. - T: Translation matrix of shape (V, 3), where V is number of views returned. - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned. + - voxels: Voxels of shape (D, D, D), where D is the number of voxels along each + dimension. """ if isinstance(model_idx, tuple): model_idx, view_idxs = model_idx @@ -208,6 +235,7 @@ class R2N2(ShapeNetBase): model["label"] = self.synset_dict[model["synset_id"]] model["images"] = None + images, Rs, Ts, voxel_RTs = [], [], [], [] # Retrieve R2N2's renderings if required. if self.return_images: rendering_path = path.join( @@ -217,12 +245,9 @@ class R2N2(ShapeNetBase): model["model_id"], "rendering", ) - # Read metadata file to obtain params for calibration matrices. with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f: metadata_lines = f.readlines() - - images, Rs, Ts = [], [], [] for i in model_views: # Read image. image_path = path.join(rendering_path, "%02d.png" % i) @@ -234,9 +259,13 @@ class R2N2(ShapeNetBase): azim, elev, yaw, dist_ratio, fov = [ float(v) for v in metadata_lines[i].strip().split(" ") ] - R, T = self._compute_camera_calibration(azim, elev, dist_ratio) + dist = dist_ratio * MAX_CAMERA_DISTANCE + # Extrinsic matrix before transformation to PyTorch3D world space. + RT = compute_extrinsic_matrix(azim, elev, dist) + R, T = self._compute_camera_calibration(RT) Rs.append(R) Ts.append(T) + voxel_RTs.append(RT) # Intrinsic matrix extracted from the Blender with slight modification to work with # PyTorch3D world space. Taken from meshrcnn codebase: @@ -254,27 +283,48 @@ class R2N2(ShapeNetBase): model["T"] = torch.stack(Ts) model["K"] = K.expand(len(model_views), 4, 4) + voxels_list = [] + # Read voxels if required. + voxel_path = path.join( + self.r2n2_dir, + "ShapeNetVox32", + model["synset_id"], + model["model_id"], + "model.binvox", + ) + if self.return_voxels: + if not path.isfile(voxel_path): + msg = "Voxel file not found for model %s from category %s." + raise FileNotFoundError(msg % (model["model_id"], model["synset_id"])) + + with open(voxel_path, "rb") as f: + # Read voxel coordinates as a tensor of shape (N, 3). + voxel_coords = read_binvox_coords(f) + # Align voxels to the same coordinate system as mesh verts. + voxel_coords = align_bbox(voxel_coords, model["verts"]) + for RT in voxel_RTs: + # Compute projection matrix. + P = BLENDER_INTRINSIC.mm(RT) + # Convert voxel coordinates of shape (N, 3) to voxels of shape (D, D, D). + voxels = voxelize(voxel_coords, P, VOXEL_SIZE) + voxels_list.append(voxels) + model["voxels"] = torch.stack(voxels_list) + return model - def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float): + def _compute_camera_calibration(self, RT): """ - Helper function for calculating rotation and translation matrices from azimuth - angle, elevation and distance ratio. + Helper function for calculating rotation and translation matrices from ShapeNet + to camera transformation and ShapeNet to PyTorch3D transformation. Args: - azim: Rotation about the z-axis, in degrees. - elev: Rotation above the xy-plane, in degrees. - dist_ratio: Ratio of distance from the origin to the maximum camera distance. + RT: Extrinsic matrix that performs ShapeNet world view to camera view + transformation. Returns: - - R: Rotation matrix of shape (3, 3). - - T: Translation matrix of shape (3). 
+ R: Rotation matrix of shape (3, 3). + T: Translation matrix of shape (3). """ - # Retrive R,T,K of the selected view(s) by reading the metadata. - MAX_CAMERA_DISTANCE = 1.75 # Constant from R2N2. - dist = dist_ratio * MAX_CAMERA_DISTANCE - RT = compute_extrinsic_matrix(azim, elev, dist) - # Transform the mesh vertices from shapenet world to pytorch3d world. shapenet_to_pytorch3d = torch.tensor( [ @@ -285,9 +335,7 @@ class R2N2(ShapeNetBase): ], dtype=torch.float32, ) - RT = compute_extrinsic_matrix(azim, elev, dist) # (4, 4) RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d) # (4, 4) - # Extract rotation and translation matrices from RT. R = RT[:3, :3] T = RT[3, :3] @@ -348,27 +396,3 @@ class R2N2(ShapeNetBase): return super().render( idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs ) - - -class BlenderCamera(CamerasBase): - """ - Camera for rendering objects with calibration matrices from the R2N2 dataset - (which uses Blender for rendering the views for each model). - """ - - def __init__(self, R=r, T=t, K=k, device="cpu"): - """ - Args: - R: Rotation matrix of shape (N, 3, 3). - T: Translation matrix of shape (N, 3). - K: Intrinsic matrix of shape (N, 4, 4). - device: torch.device or str. - """ - # The initializer formats all inputs to torch tensors and broadcasts - # all the inputs to have the same batch dimension where necessary. - super().__init__(device=device, R=R, T=T, K=K) - - def get_projection_transform(self, **kwargs) -> Transform3d: - transform = Transform3d(device=self.device) - transform._matrix = self.K.transpose(1, 2).contiguous() # pyre-ignore[16] - return transform diff --git a/pytorch3d/datasets/r2n2/utils.py b/pytorch3d/datasets/r2n2/utils.py new file mode 100644 index 00000000..f82636b1 --- /dev/null +++ b/pytorch3d/datasets/r2n2/utils.py @@ -0,0 +1,483 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +import math +from typing import Dict, List + +import numpy as np +import torch +from pytorch3d.datasets.utils import collate_batched_meshes +from pytorch3d.ops import cubify +from pytorch3d.renderer import ( + HardPhongShader, + MeshRasterizer, + MeshRenderer, + PointLights, + RasterizationSettings, + TexturesVertex, +) +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.transforms import Transform3d + + +# Empirical min and max over the dataset from meshrcnn. +# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L9 +SHAPENET_MIN_ZMIN = 0.67 +SHAPENET_MAX_ZMAX = 0.92 +# Threshold for cubify from meshrcnn: +# https://github.com/facebookresearch/meshrcnn/blob/master/configs/shapenet/voxmesh_R50.yaml#L11 +CUBIFY_THRESH = 0.2 + +# Default values of rotation, translation and intrinsic matrices for BlenderCamera. +r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3) +t = np.expand_dims(np.zeros(3), axis=0) # (1, 3) +k = np.expand_dims(np.eye(4), axis=0) # (1, 4, 4) + + +def collate_batched_R2N2(batch: List[Dict]): + """ + Take a list of objects in the form of dictionaries and merge them + into a single dictionary. This function can be used with a Dataset + object to create a torch.utils.data.Dataloader which directly + returns Meshes objects. + TODO: Add support for textures. + + Args: + batch: List of dictionaries containing information about objects + in the dataset. + + Returns: + collated_dict: Dictionary of collated lists. If batch contains both + verts and faces, a collated mesh batch is also returned. 
+ """ + collated_dict = collate_batched_meshes(batch) + + # If collate_batched_meshes receives R2N2 items with images and that + # all models have the same number of views V, stack the batches of + # views of each model into a new batch of shape (N, V, H, W, 3). + # Otherwise leave it as a list. + if "images" in collated_dict: + try: + collated_dict["images"] = torch.stack(collated_dict["images"]) + except RuntimeError: + print( + "Models don't have the same number of views. Now returning " + "lists of images instead of batches." + ) + + # If collate_batched_meshes receives R2N2 items with camera calibration + # matrices and that all models have the same number of views V, stack each + # type of matrices into a new batch of shape (N, V, ...). + # Otherwise leave them as lists. + if all(x in collated_dict for x in ["R", "T", "K"]): + try: + collated_dict["R"] = torch.stack(collated_dict["R"]) # (N, V, 3, 3) + collated_dict["T"] = torch.stack(collated_dict["T"]) # (N, V, 3) + collated_dict["K"] = torch.stack(collated_dict["K"]) # (N, V, 4, 4) + except RuntimeError: + print( + "Models don't have the same number of views. Now returning " + "lists of calibration matrices instead of a batched tensor." + ) + + # If collate_batched_meshes receives voxels and all models have the same + # number of views V, stack the batches of voxels into a new batch of shape + # (N, V, S, S, S), where S is the voxel size. + if "voxels" in collated_dict: + try: + collated_dict["voxels"] = torch.stack(collated_dict["voxels"]) + except RuntimeError: + print( + "Models don't have the same number of views. Now returning " + "lists of voxels instead of a batched tensor." + ) + return collated_dict + + +def compute_extrinsic_matrix(azimuth, elevation, distance): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96 + + Compute 4x4 extrinsic matrix that converts from homogenous world coordinates + to homogenous camera coordinates. We assume that the camera is looking at the + origin. + Used in R2N2 Dataset when computing calibration matrices. + + Args: + azimuth: Rotation about the z-axis, in degrees. + elevation: Rotation above the xy-plane, in degrees. + distance: Distance from the origin. + + Returns: + FloatTensor of shape (4, 4). + """ + azimuth, elevation, distance = float(azimuth), float(elevation), float(distance) + + az_rad = -math.pi * azimuth / 180.0 + el_rad = -math.pi * elevation / 180.0 + sa = math.sin(az_rad) + ca = math.cos(az_rad) + se = math.sin(el_rad) + ce = math.cos(el_rad) + R_world2obj = torch.tensor( + [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]] + ) + R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + R_world2cam = R_obj2cam.mm(R_world2obj) + cam_location = torch.tensor([[distance, 0, 0]]).t() + T_world2cam = -(R_obj2cam.mm(cam_location)) + RT = torch.cat([R_world2cam, T_world2cam], dim=1) + RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])]) + + # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it + # rotates the model 90 degrees about the x axis. 
To compensate for this quirk we + # roll that rotation into the extrinsic matrix here + rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) + RT = RT.mm(rot.to(RT)) + + return RT + + +def read_binvox_coords(f, integer_division=True, dtype=torch.float32): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L5 + + Read a binvox file and return the indices of all nonzero voxels. + + This matches the behavior of binvox_rw.read_as_coord_array + (https://github.com/dimatura/binvox-rw-py/blob/public/binvox_rw.py#L153) + but this implementation uses torch rather than numpy, and is more efficient + due to improved vectorization. + + Georgia: I think that binvox_rw.read_as_coord_array actually has a bug; when converting + linear indices into three-dimensional indices, they use floating-point + division instead of integer division. We can reproduce their incorrect + implementation by passing integer_division=False. + + Args: + f (str): A file pointer to the binvox file to read + integer_division (bool): If False, then match the buggy implementation from binvox_rw + dtype: Datatype of the output tensor. Use float64 to match binvox_rw + + Returns: + coords (tensor): A tensor of shape (N, 3) where N is the number of nonzero voxels, + and coords[i] = (x, y, z) gives the index of the ith nonzero voxel. If the + voxel grid has shape (V, V, V) then we have 0 <= x, y, z < V. + """ + size, translation, scale = _read_binvox_header(f) + storage = torch.ByteStorage.from_buffer(f.read()) + data = torch.tensor([], dtype=torch.uint8) + data.set_(source=storage) + vals, counts = data[::2], data[1::2] + idxs = _compute_idxs(vals, counts) + if not integer_division: + idxs = idxs.to(dtype) + x_idxs = idxs // (size * size) + zy_idxs = idxs % (size * size) + z_idxs = zy_idxs // size + y_idxs = zy_idxs % size + coords = torch.stack([x_idxs, y_idxs, z_idxs], dim=1) + return coords.to(dtype) + + +def _compute_idxs(vals, counts): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L58 + + Fast vectorized version of index computation. + + Args: + vals: tensor of binary values indicating voxel presence in a dense format. + counts: tensor of number of occurence of each value in vals. + + Returns: + idxs: A tensor of shape (N), where N is the number of nonzero voxels. 
+ """ + # Consider an example where: + # vals = [0, 1, 0, 1, 1] + # counts = [2, 3, 3, 2, 1] + # + # These values of counts and vals mean that the dense binary grid is: + # [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1] + # + # So the nonzero indices we want to return are: + # [2, 3, 4, 8, 9, 10] + + # After the cumsum we will have: + # end_idxs = [2, 5, 8, 10, 11] + end_idxs = counts.cumsum(dim=0) + + # After masking and computing start_idx we have: + # end_idxs = [5, 10, 11] + # counts = [3, 2, 1] + # start_idxs = [2, 8, 10] + mask = vals == 1 + end_idxs = end_idxs[mask] + counts = counts[mask].to(end_idxs) + start_idxs = end_idxs - counts + + # We initialize delta as: + # [2, 1, 1, 1, 1, 1] + delta = torch.ones(counts.sum().item(), dtype=torch.int64) + delta[0] = start_idxs[0] + + # We compute pos = [3, 5], val = [3, 0]; then delta is + # [2, 1, 1, 4, 1, 1] + pos = counts.cumsum(dim=0)[:-1] + val = start_idxs[1:] - end_idxs[:-1] + delta[pos] += val + + # A final cumsum gives the idx we want: [2, 3, 4, 8, 9, 10] + idxs = delta.cumsum(dim=0) + return idxs + + +def _read_binvox_header(f): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L99 + + Read binvox header and extract information regarding voxel sizes and translations + to original voxel coordinates. + + Args: + f (str): A file pointer to the binvox file to read. + + Returns: + size (int): size of voxel. + translation (tuple(float)): translation to original voxel coordinates. + scale (float): scale to original voxel coordinates. + """ + # First line of the header should be "#binvox 1" + line = f.readline().strip() + if line != b"#binvox 1": + raise ValueError("Invalid header (line 1)") + + # Second line of the header should be "dim [int] [int] [int]" + # and all three int should be the same + line = f.readline().strip() + if not line.startswith(b"dim "): + raise ValueError("Invalid header (line 2)") + dims = line.split(b" ") + try: + dims = [int(d) for d in dims[1:]] + except ValueError: + raise ValueError("Invalid header (line 2)") + if len(dims) != 3 or dims[0] != dims[1] or dims[0] != dims[2]: + raise ValueError("Invalid header (line 2)") + size = dims[0] + + # Third line of the header should be "translate [float] [float] [float]" + line = f.readline().strip() + if not line.startswith(b"translate "): + raise ValueError("Invalid header (line 3)") + translation = line.split(b" ") + if len(translation) != 4: + raise ValueError("Invalid header (line 3)") + try: + translation = tuple(float(t) for t in translation[1:]) + except ValueError: + raise ValueError("Invalid header (line 3)") + + # Fourth line of the header should be "scale [float]" + line = f.readline().strip() + if not line.startswith(b"scale "): + raise ValueError("Invalid header (line 4)") + line = line.split(b" ") + if not len(line) == 2: + raise ValueError("Invalid header (line 4)") + scale = float(line[1]) + + # Fifth line of the header should be "data" + line = f.readline().strip() + if not line == b"data": + raise ValueError("Invalid header (line 5)") + + return size, translation, scale + + +def align_bbox(src, tgt): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/tools/preprocess_shapenet.py#L263 + + Return a copy of src points in the coordinate system of tgt by applying a + scale and shift along each coordinate axis to make the min / max values align. 
+ + Args: + src, tgt: Torch Tensor of shape (N, 3) + + Returns: + out: Torch Tensor of shape (N, 3) + """ + if src.ndim != 2 or tgt.ndim != 2: + raise ValueError("Both src and tgt need to have dimensions of 2.") + if src.shape[-1] != 3 or tgt.shape[-1] != 3: + raise ValueError( + "Both src and tgt need to have sizes of 3 along the second dimension." + ) + src_min = src.min(dim=0)[0] + src_max = src.max(dim=0)[0] + tgt_min = tgt.min(dim=0)[0] + tgt_max = tgt.max(dim=0)[0] + scale = (tgt_max - tgt_min) / (src_max - src_min) + shift = tgt_min - scale * src_min + out = scale * src + shift + return out + + +def voxelize(voxel_coords, P, V): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/tools/preprocess_shapenet.py#L284 + but changing flip y to flip x. + + Create voxels of shape (D, D, D) from voxel_coords and a projection matrix. + + Args: + voxel_coords: FloatTensor of shape (N, 3) giving the voxel coordinates aligned to + the mesh vertices. + P: FloatTensor of shape (4, 4) giving the projection matrix. + V: Voxel size of the output. + + Returns: + voxels: Tensor of shape (D, D, D) giving the voxelized result. + """ + device = voxel_coords.device + voxel_coords = project_verts(voxel_coords, P) + + # Using the actual zmin and zmax of the model is bad because we need them + # to perform the inverse transform, which transforms voxels back into world + # space for refinement or evaluation. Instead we use an empirical min and + # max over the dataset; that way it is consistent for all images. + zmin = SHAPENET_MIN_ZMIN + zmax = SHAPENET_MAX_ZMAX + + # Once we know zmin and zmax, we need to adjust the z coordinates so the + # range [zmin, zmax] instead runs from [-1, 1] + m = 2.0 / (zmax - zmin) + b = -2.0 * zmin / (zmax - zmin) - 1 + voxel_coords[:, 2].mul_(m).add_(b) + voxel_coords[:, 0].mul_(-1) # Flip x + + # Now voxels are in [-1, 1]^3; map to [0, V-1)^3 + voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0) + voxel_coords = voxel_coords.round().to(torch.int64) + valid = (0 <= voxel_coords) * (voxel_coords < V) + valid = valid[:, 0] * valid[:, 1] * valid[:, 2] + x, y, z = voxel_coords.unbind(dim=1) + x, y, z = x[valid], y[valid], z[valid] + voxels = torch.zeros(V, V, V, dtype=torch.uint8, device=device) + voxels[z, y, x] = 1 + + return voxels + + +def project_verts(verts, P, eps=1e-1): + """ + Copied from meshrcnn codebase: + https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L159 + + Project vertices using a 4x4 transformation matrix. + + Args: + verts: FloatTensor of shape (N, V, 3) giving a batch of vertex positions or of + shape (V, 3) giving a single set of vertex positions. + P: FloatTensor of shape (N, 4, 4) giving projection matrices or of shape (4, 4) + giving a single projection matrix. + + Returns: + verts_out: FloatTensor of shape (N, V, 3) giving vertex positions (x, y, z) + where verts_out[i] is the result of transforming verts[i] by P[i]. + """ + # Handle unbatched inputs + singleton = False + if verts.dim() == 2: + assert P.dim() == 2 + singleton = True + verts, P = verts[None], P[None] + + N, V = verts.shape[0], verts.shape[1] + dtype, device = verts.dtype, verts.device + + # Add an extra row of ones to the world-space coordinates of verts before + # multiplying by the projection matrix. We could avoid this allocation by + # instead multiplying by a 4x3 submatrix of the projection matrix, then + # adding the remaining 4x1 vector.
Not sure whether there will be much + # performance difference between the two. + ones = torch.ones(N, V, 1, dtype=dtype, device=device) + verts_hom = torch.cat([verts, ones], dim=2) + verts_cam_hom = torch.bmm(verts_hom, P.transpose(1, 2)) + + # Avoid division by zero by clamping the absolute value + w = verts_cam_hom[:, :, 3:] + w_sign = w.sign() + w_sign[w == 0] = 1 + w = w_sign * w.abs().clamp(min=eps) + + verts_proj = verts_cam_hom[:, :, :3] / w + + if singleton: + return verts_proj[0] + return verts_proj + + + class BlenderCamera(CamerasBase): + """ + Camera for rendering objects with calibration matrices from the R2N2 dataset + (which uses Blender for rendering the views for each model). + """ + + def __init__(self, R=r, T=t, K=k, device="cpu"): + """ + Args: + R: Rotation matrix of shape (N, 3, 3). + T: Translation matrix of shape (N, 3). + K: Intrinsic matrix of shape (N, 4, 4). + device: torch.device or str. + """ + # The initializer formats all inputs to torch tensors and broadcasts + # all the inputs to have the same batch dimension where necessary. + super().__init__(device=device, R=R, T=T, K=K) + + def get_projection_transform(self, **kwargs) -> Transform3d: + transform = Transform3d(device=self.device) + transform._matrix = self.K.transpose(1, 2).contiguous() # pyre-ignore[16] + return transform + + +def render_cubified_voxels( + voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs +): + """ + Use the Cubify operator to convert input voxels to a mesh and then render that mesh. + + Args: + voxels: FloatTensor of shape (N, D, D, D) where N is the batch size and + D is the number of voxels along each dimension. + shader_type: Shader to use for rendering. Examples + include HardPhongShader (default), SoftPhongShader etc. or any other type + of valid Shader class. + device: torch.device on which the tensors should be located. + **kwargs: Accepts any of the kwargs that the renderer supports. + Returns: + Batch of rendered images of shape (N, H, W, 3). + """ + cubified_voxels = cubify(voxels, CUBIFY_THRESH).to(device) + cubified_voxels.textures = TexturesVertex( + verts_features=torch.ones_like(cubified_voxels.verts_padded(), device=device) + ) + cameras = BlenderCamera(device=device) + renderer = MeshRenderer( + rasterizer=MeshRasterizer( + cameras=cameras, + raster_settings=kwargs.get("raster_settings", RasterizationSettings()), + ), + shader=shader_type( + device=device, + cameras=cameras, + lights=kwargs.get("lights", PointLights()).to(device), + ), + ) + return renderer(cubified_voxels) diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py index 43243f5c..5d8dd2ae 100644 --- a/pytorch3d/datasets/utils.py +++ b/pytorch3d/datasets/utils.py @@ -1,8 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -import math from typing import Dict, List -import torch from pytorch3d.structures import Meshes @@ -33,78 +31,4 @@ def collate_batched_meshes(batch: List[Dict]): collated_dict["mesh"] = Meshes( verts=collated_dict["verts"], faces=collated_dict["faces"] ) - - # If collate_batched_meshes receives R2N2 items with images and that - # all models have the same number of views V, stack the batches of - # views of each model into a new batch of shape (N, V, H, W, 3). - # Otherwise leave it as a list. - if "images" in collated_dict: - try: - collated_dict["images"] = torch.stack(collated_dict["images"]) - except RuntimeError: - print( - "Models don't have the same number of views. 
Now returning " - "lists of images instead of batches." - ) - - # If collate_batched_meshes receives R2N2 items with camera calibration - # matrices and that all models have the same number of views V, stack each - # type of matrices into a new batch of shape (N, V, ...). - # Otherwise leave them as lists. - if all(x in collated_dict for x in ["R", "T", "K"]): - try: - collated_dict["R"] = torch.stack(collated_dict["R"]) # (N, V, 3, 3) - collated_dict["T"] = torch.stack(collated_dict["T"]) # (N, V, 3) - collated_dict["K"] = torch.stack(collated_dict["K"]) # (N, V, 4, 4) - except RuntimeError: - print( - "Models don't have the same number of views. Now returning " - "lists of calibration matrices instead of batches." - ) - return collated_dict - - -def compute_extrinsic_matrix(azimuth, elevation, distance): - """ - Copied from meshrcnn codebase: - https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96 - - Compute 4x4 extrinsic matrix that converts from homogenous world coordinates - to homogenous camera coordinates. We assume that the camera is looking at the - origin. - Used in R2N2 Dataset when computing calibration matrices. - - Args: - azimuth: Rotation about the z-axis, in degrees. - elevation: Rotation above the xy-plane, in degrees. - distance: Distance from the origin. - - Returns: - FloatTensor of shape (4, 4). - """ - azimuth, elevation, distance = float(azimuth), float(elevation), float(distance) - - az_rad = -math.pi * azimuth / 180.0 - el_rad = -math.pi * elevation / 180.0 - sa = math.sin(az_rad) - ca = math.cos(az_rad) - se = math.sin(el_rad) - ce = math.cos(el_rad) - R_world2obj = torch.tensor( - [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]] - ) - R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) - R_world2cam = R_obj2cam.mm(R_world2obj) - cam_location = torch.tensor([[distance, 0, 0]]).t() - T_world2cam = -(R_obj2cam.mm(cam_location)) - RT = torch.cat([R_world2cam, T_world2cam], dim=1) - RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])]) - - # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it - # rotates the model 90 degrees about the x axis. To compensate for this quirk we - # roll that rotation into the extrinsic matrix here - rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) - RT = RT.mm(rot.to(RT)) - - return RT diff --git a/tests/data/test_r2n2_voxel_to_mesh_render.png b/tests/data/test_r2n2_voxel_to_mesh_render.png new file mode 100644 index 00000000..d458298a Binary files /dev/null and b/tests/data/test_r2n2_voxel_to_mesh_render.png differ diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py index 01ab3274..0765e699 100644 --- a/tests/test_r2n2.py +++ b/tests/test_r2n2.py @@ -11,7 +11,12 @@ import numpy as np import torch from common_testing import TestCaseMixin, load_rgb_image from PIL import Image -from pytorch3d.datasets import R2N2, BlenderCamera, collate_batched_meshes +from pytorch3d.datasets import ( + R2N2, + BlenderCamera, + collate_batched_R2N2, + render_cubified_voxels, +) from pytorch3d.renderer import ( OpenGLPerspectiveCameras, PointLights, @@ -62,8 +67,10 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): Test the loaded train split of R2N2 return items of the correct shapes and types. Also check the first image returned is correct. """ - # Load dataset in the test split. - r2n2_dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) + # Load dataset in the train split. 
+ r2n2_dataset = R2N2( + "test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True + ) # Check total number of objects in the dataset is correct. with open(SPLITS_PATH) as splits: @@ -114,6 +121,10 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(r2n2_dataset[635]["images"].shape[0], 5) self.assertEqual(r2n2_dataset[8369]["images"].shape[0], 10) + # Check that the voxel tensor returned by __getitem__ has the correct shape. + self.assertEqual(r2n2_obj["voxels"].ndim, 4) + self.assertEqual(r2n2_obj["voxels"].shape, (24, 128, 128, 128)) + def test_collate_models(self): """ Test collate_batched_meshes returns items of the correct shapes and types. @@ -121,10 +132,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): the correct shapes and types are returned. """ # Load dataset in the train split. - r2n2_dataset = R2N2("val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH) + r2n2_dataset = R2N2( + "val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True + ) # Randomly retrieve several objects from the dataset and collate them. - collated_meshes = collate_batched_meshes( + collated_meshes = collate_batched_R2N2( [r2n2_dataset[idx] for idx in torch.randint(len(r2n2_dataset), (6,))] ) # Check the collated verts and faces have the correct shapes. @@ -145,7 +158,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): # in batch have the correct shape. batch_size = 12 r2n2_loader = DataLoader( - r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes + r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_R2N2 ) it = iter(r2n2_loader) object_batch = next(it) @@ -158,6 +171,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): self.assertEqual(object_batch["R"].shape[0], batch_size) self.assertEqual(object_batch["T"].shape[0], batch_size) self.assertEqual(object_batch["K"].shape[0], batch_size) + self.assertEqual(len(object_batch["voxels"]), batch_size) def test_catch_render_arg_errors(self): """ @@ -338,3 +352,24 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): "test_r2n2_render_with_blender_calibrations_%s.png" % idx, DATA_DIR ) self.assertClose(r2n2_batch_rgb, image_ref, atol=0.05) + + def test_render_voxels(self): + """ + Test rendering meshes formed from voxels. + """ + # Set up device and seed for random selections. + device = torch.device("cuda:0") + + # Load dataset in the train split with only a single view returned for each model. + r2n2_dataset = R2N2( + "train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True + ) + r2n2_model = r2n2_dataset[6, [5]] + vox_render = render_cubified_voxels(r2n2_model["voxels"], device=device) + vox_render_rgb = vox_render[0, ..., :3].squeeze().cpu() + if DEBUG: + Image.fromarray((vox_render_rgb.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / ("DEBUG_r2n2_voxel_to_mesh_render.png") + ) + image_ref = load_rgb_image("test_r2n2_voxel_to_mesh_render.png", DATA_DIR) + self.assertClose(vox_render_rgb, image_ref, atol=0.05)
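A minimal usage sketch of the API added in this diff, for reference. It exercises `R2N2(..., return_voxels=True)`, `collate_batched_R2N2`, and `render_cubified_voxels` the same way the tests above do; the dataset paths, the split, the model/view indices, and the batch size are illustrative placeholders, not values prescribed by this change:

```python
import torch
from torch.utils.data import DataLoader

from pytorch3d.datasets import R2N2, collate_batched_R2N2, render_cubified_voxels

# Placeholder paths: point these at local copies of ShapeNetCore, the R2N2
# renderings/voxels, and the R2N2 splits file.
SHAPENET_PATH = "/path/to/ShapeNetCore"
R2N2_PATH = "/path/to/R2N2"
SPLITS_PATH = "/path/to/r2n2_splits.json"

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Ask the dataset to return voxels alongside meshes, images and R/T/K matrices.
r2n2_dataset = R2N2(
    "train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
)

# Retrieve one model with a single view; "voxels" holds one 128^3 grid per
# returned view, aligned with that view's Blender camera.
model = r2n2_dataset[6, [5]]
print(model["voxels"].shape)

# Cubify the voxel grid(s) and render the resulting mesh(es); take the RGB
# channels of the first rendered view, as in test_render_voxels above.
images = render_cubified_voxels(model["voxels"], device=device)
rgb = images[0, ..., :3]

# Batch models with the R2N2-aware collate function; images, R/T/K and voxels
# are stacked when all models in the batch have the same number of views.
loader = DataLoader(r2n2_dataset, batch_size=4, collate_fn=collate_batched_R2N2)
batch = next(iter(loader))
print(len(batch["voxels"]))
```

Note that `collate_batched_R2N2` supersedes plain `collate_batched_meshes` for R2N2 items now that the view-stacking logic lives in `pytorch3d/datasets/r2n2/utils.py`.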