mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00

Return R2N2 voxel coordinates

Summary: Return R2N2's voxel coordinates.

Reviewed By: nikhilaravi

Differential Revision: D22462530

fbshipit-source-id: a995cfa0957b2561eb3b0f4591cb1db42170bc68

This commit is contained in:
parent 326e4ccb5b
commit 63ba74f1a8
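For orientation, a minimal usage sketch of the API added by this commit (as in the tests below, SHAPENET_PATH, R2N2_PATH and SPLITS_PATH are assumed to point to local copies of ShapeNet, the R2N2 renderings and the R2N2 splits file):

    from torch.utils.data import DataLoader
    from pytorch3d.datasets import R2N2, collate_batched_R2N2, render_cubified_voxels

    # Ask the dataset for voxels in addition to images and calibration matrices.
    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True)
    model = r2n2_dataset[0]
    # model["voxels"] holds one (D, D, D) occupancy grid per returned view (D = 128).

    # Batch items with the R2N2-aware collate function added here.
    loader = DataLoader(r2n2_dataset, batch_size=4, collate_fn=collate_batched_R2N2)
    batch = next(iter(loader))

    # Cubify and render the voxels of one model.
    images = render_cubified_voxels(model["voxels"])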
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from .r2n2 import R2N2, BlenderCamera
+from .r2n2 import R2N2, BlenderCamera, collate_batched_R2N2, render_cubified_voxels
 from .shapenet import ShapeNetCore
 from .utils import collate_batched_meshes

@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-from .r2n2 import R2N2, BlenderCamera
+from .r2n2 import R2N2
+from .utils import BlenderCamera, collate_batched_R2N2, render_cubified_voxels


 __all__ = [k for k in globals().keys() if not k.startswith("_")]
@@ -10,20 +10,32 @@ import numpy as np
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
-from pytorch3d.datasets.utils import compute_extrinsic_matrix
 from pytorch3d.io import load_obj
 from pytorch3d.renderer import HardPhongShader
-from pytorch3d.renderer.cameras import CamerasBase
-from pytorch3d.transforms import Transform3d
 from tabulate import tabulate

+from .utils import (
+    BlenderCamera,
+    align_bbox,
+    compute_extrinsic_matrix,
+    read_binvox_coords,
+    voxelize,
+)
+

 SYNSET_DICT_DIR = Path(__file__).resolve().parent

-# Default values of rotation, translation and intrinsic matrices for BlenderCamera.
-r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
-t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
-k = np.expand_dims(np.eye(4), axis=0)  # (1, 4, 4)
+MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
+VOXEL_SIZE = 128
+# Intrinsic matrix extracted from Blender. Taken from meshrcnn codebase:
+# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py
+BLENDER_INTRINSIC = torch.tensor(
+    [
+        [2.1875, 0.0, 0.0, 0.0],
+        [0.0, 2.1875, 0.0, 0.0],
+        [0.0, 0.0, -1.002002, -0.2002002],
+        [0.0, 0.0, -1.0, 0.0],
+    ]
+)


 class R2N2(ShapeNetBase):
@@ -42,6 +54,7 @@ class R2N2(ShapeNetBase):
         r2n2_dir,
         splits_file,
         return_all_views: bool = True,
+        return_voxels: bool = False,
     ):
         """
         Store each object's synset id and models id the given directories.
@@ -54,6 +67,8 @@ class R2N2(ShapeNetBase):
             return_all_views (bool): Indicator of whether or not to load all the views in
                 the split. If set to False, one of the views in the split will be randomly
                 selected and loaded.
+            return_voxels(bool): Indicator of whether or not to return voxels as a tensor
+                of shape (D, D, D) where D is the number of voxels along each dimension.
         """
         super().__init__()
         self.shapenet_dir = shapenet_dir
@@ -83,6 +98,16 @@ class R2N2(ShapeNetBase):
             ) % (r2n2_dir)
             warnings.warn(msg)

+        self.return_voxels = return_voxels
+        # Check if the folder containing voxel coordinates is included in r2n2_dir.
+        if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")):
+            self.return_voxels = False
+            msg = (
+                "ShapeNetVox32 not found in %s. Voxel coordinates will "
+                "be skipped when returning models."
+            ) % (r2n2_dir)
+            warnings.warn(msg)
+
         synset_set = set()
         # Store lists of views of each model in a list.
         self.views_per_model_list = []
@@ -173,6 +198,8 @@ class R2N2(ShapeNetBase):
            - R: Rotation matrix of shape (V, 3, 3), where V is number of views returned.
            - T: Translation matrix of shape (V, 3), where V is number of views returned.
            - K: Intrinsic matrix of shape (V, 4, 4), where V is number of views returned.
+           - voxels: Voxels of shape (D, D, D), where D is the number of voxels along each
+               dimension.
         """
         if isinstance(model_idx, tuple):
             model_idx, view_idxs = model_idx
@@ -208,6 +235,7 @@ class R2N2(ShapeNetBase):
         model["label"] = self.synset_dict[model["synset_id"]]

         model["images"] = None
+        images, Rs, Ts, voxel_RTs = [], [], [], []
         # Retrieve R2N2's renderings if required.
         if self.return_images:
             rendering_path = path.join(
@@ -217,12 +245,9 @@ class R2N2(ShapeNetBase):
                 model["model_id"],
                 "rendering",
             )
-
             # Read metadata file to obtain params for calibration matrices.
             with open(path.join(rendering_path, "rendering_metadata.txt"), "r") as f:
                 metadata_lines = f.readlines()
-
-            images, Rs, Ts = [], [], []
             for i in model_views:
                 # Read image.
                 image_path = path.join(rendering_path, "%02d.png" % i)
@@ -234,9 +259,13 @@ class R2N2(ShapeNetBase):
                 azim, elev, yaw, dist_ratio, fov = [
                     float(v) for v in metadata_lines[i].strip().split(" ")
                 ]
-                R, T = self._compute_camera_calibration(azim, elev, dist_ratio)
+                dist = dist_ratio * MAX_CAMERA_DISTANCE
+                # Extrinsic matrix before transformation to PyTorch3D world space.
+                RT = compute_extrinsic_matrix(azim, elev, dist)
+                R, T = self._compute_camera_calibration(RT)
                 Rs.append(R)
                 Ts.append(T)
+                voxel_RTs.append(RT)

             # Intrinsic matrix extracted from the Blender with slight modification to work with
             # PyTorch3D world space. Taken from meshrcnn codebase:
@@ -254,27 +283,48 @@ class R2N2(ShapeNetBase):
             model["T"] = torch.stack(Ts)
             model["K"] = K.expand(len(model_views), 4, 4)

+        voxels_list = []
+        # Read voxels if required.
+        voxel_path = path.join(
+            self.r2n2_dir,
+            "ShapeNetVox32",
+            model["synset_id"],
+            model["model_id"],
+            "model.binvox",
+        )
+        if self.return_voxels:
+            if not path.isfile(voxel_path):
+                msg = "Voxel file not found for model %s from category %s."
+                raise FileNotFoundError(msg % (model["model_id"], model["synset_id"]))
+
+            with open(voxel_path, "rb") as f:
+                # Read voxel coordinates as a tensor of shape (N, 3).
+                voxel_coords = read_binvox_coords(f)
+            # Align voxels to the same coordinate system as mesh verts.
+            voxel_coords = align_bbox(voxel_coords, model["verts"])
+            for RT in voxel_RTs:
+                # Compute projection matrix.
+                P = BLENDER_INTRINSIC.mm(RT)
+                # Convert voxel coordinates of shape (N, 3) to voxels of shape (D, D, D).
+                voxels = voxelize(voxel_coords, P, VOXEL_SIZE)
+                voxels_list.append(voxels)
+            model["voxels"] = torch.stack(voxels_list)
+
         return model

-    def _compute_camera_calibration(self, azim: float, elev: float, dist_ratio: float):
+    def _compute_camera_calibration(self, RT):
         """
-        Helper function for calculating rotation and translation matrices from azimuth
-        angle, elevation and distance ratio.
+        Helper function for calculating rotation and translation matrices from ShapeNet
+        to camera transformation and ShapeNet to PyTorch3D transformation.

         Args:
-            azim: Rotation about the z-axis, in degrees.
-            elev: Rotation above the xy-plane, in degrees.
-            dist_ratio: Ratio of distance from the origin to the maximum camera distance.
+            RT: Extrinsic matrix that performs ShapeNet world view to camera view
+                transformation.

         Returns:
-            - R: Rotation matrix of shape (3, 3).
-            - T: Translation matrix of shape (3).
+            R: Rotation matrix of shape (3, 3).
+            T: Translation matrix of shape (3).
         """
-        # Retrive R,T,K of the selected view(s) by reading the metadata.
-        MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
-        dist = dist_ratio * MAX_CAMERA_DISTANCE
-        RT = compute_extrinsic_matrix(azim, elev, dist)
-
         # Transform the mesh vertices from shapenet world to pytorch3d world.
         shapenet_to_pytorch3d = torch.tensor(
             [
@@ -285,9 +335,7 @@ class R2N2(ShapeNetBase):
             ],
             dtype=torch.float32,
         )
-        RT = compute_extrinsic_matrix(azim, elev, dist)  # (4, 4)
         RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)

         # Extract rotation and translation matrices from RT.
         R = RT[:3, :3]
         T = RT[3, :3]
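To illustrate the calibration path above outside the class, a sketch (the shapenet_to_pytorch3d values are not shown in this hunk, so the diag(-1, 1, -1, 1) below is an assumption recalled from the surrounding source, and the metadata numbers are made up):

    import torch
    from pytorch3d.datasets.r2n2.utils import compute_extrinsic_matrix

    MAX_CAMERA_DISTANCE = 1.75  # Constant from R2N2.
    azim, elev, dist_ratio = 135.0, 30.0, 0.7  # example metadata values for one view

    # ShapeNet world -> camera extrinsic matrix.
    RT = compute_extrinsic_matrix(azim, elev, dist_ratio * MAX_CAMERA_DISTANCE)

    # Move into the PyTorch3D world frame, then split into rotation and translation.
    shapenet_to_pytorch3d = torch.diag(torch.tensor([-1.0, 1.0, -1.0, 1.0]))  # assumed values
    RT = torch.transpose(RT, 0, 1).mm(shapenet_to_pytorch3d)  # (4, 4)
    R, T = RT[:3, :3], RT[3, :3]  # (3, 3) rotation, (3,) translation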
@@ -348,27 +396,3 @@ class R2N2(ShapeNetBase):
         return super().render(
             idxs=idxs, shader_type=shader_type, device=device, cameras=cameras, **kwargs
         )
-
-
-class BlenderCamera(CamerasBase):
-    """
-    Camera for rendering objects with calibration matrices from the R2N2 dataset
-    (which uses Blender for rendering the views for each model).
-    """
-
-    def __init__(self, R=r, T=t, K=k, device="cpu"):
-        """
-        Args:
-            R: Rotation matrix of shape (N, 3, 3).
-            T: Translation matrix of shape (N, 3).
-            K: Intrinsic matrix of shape (N, 4, 4).
-            device: torch.device or str.
-        """
-        # The initializer formats all inputs to torch tensors and broadcasts
-        # all the inputs to have the same batch dimension where necessary.
-        super().__init__(device=device, R=R, T=T, K=K)
-
-    def get_projection_transform(self, **kwargs) -> Transform3d:
-        transform = Transform3d(device=self.device)
-        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
-        return transform
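The tail of R2N2.render kept above suggests the following call pattern (a sketch; it assumes the part of the method not shown here builds the per-view Blender cameras passed as cameras):

    device = torch.device("cuda:0")
    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
    # Render selected models with the R2N2 Blender calibration matrices.
    images = r2n2_dataset.render(idxs=[6], device=device)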
483	pytorch3d/datasets/r2n2/utils.py	Normal file
@@ -0,0 +1,483 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import math
+from typing import Dict, List
+
+import numpy as np
+import torch
+from pytorch3d.datasets.utils import collate_batched_meshes
+from pytorch3d.ops import cubify
+from pytorch3d.renderer import (
+    HardPhongShader,
+    MeshRasterizer,
+    MeshRenderer,
+    PointLights,
+    RasterizationSettings,
+    TexturesVertex,
+)
+from pytorch3d.renderer.cameras import CamerasBase
+from pytorch3d.transforms import Transform3d
+
+
+# Empirical min and max over the dataset from meshrcnn.
+# https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L9
+SHAPENET_MIN_ZMIN = 0.67
+SHAPENET_MAX_ZMAX = 0.92
+# Threshold for cubify from meshrcnn:
+# https://github.com/facebookresearch/meshrcnn/blob/master/configs/shapenet/voxmesh_R50.yaml#L11
+CUBIFY_THRESH = 0.2
+
+# Default values of rotation, translation and intrinsic matrices for BlenderCamera.
+r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
+t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
+k = np.expand_dims(np.eye(4), axis=0)  # (1, 4, 4)
+
+
+def collate_batched_R2N2(batch: List[Dict]):
+    """
+    Take a list of objects in the form of dictionaries and merge them
+    into a single dictionary. This function can be used with a Dataset
+    object to create a torch.utils.data.Dataloader which directly
+    returns Meshes objects.
+    TODO: Add support for textures.
+
+    Args:
+        batch: List of dictionaries containing information about objects
+            in the dataset.
+
+    Returns:
+        collated_dict: Dictionary of collated lists. If batch contains both
+            verts and faces, a collated mesh batch is also returned.
+    """
+    collated_dict = collate_batched_meshes(batch)
+
+    # If collate_batched_meshes receives R2N2 items with images and that
+    # all models have the same number of views V, stack the batches of
+    # views of each model into a new batch of shape (N, V, H, W, 3).
+    # Otherwise leave it as a list.
+    if "images" in collated_dict:
+        try:
+            collated_dict["images"] = torch.stack(collated_dict["images"])
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of images instead of batches."
+            )
+
+    # If collate_batched_meshes receives R2N2 items with camera calibration
+    # matrices and that all models have the same number of views V, stack each
+    # type of matrices into a new batch of shape (N, V, ...).
+    # Otherwise leave them as lists.
+    if all(x in collated_dict for x in ["R", "T", "K"]):
+        try:
+            collated_dict["R"] = torch.stack(collated_dict["R"])  # (N, V, 3, 3)
+            collated_dict["T"] = torch.stack(collated_dict["T"])  # (N, V, 3)
+            collated_dict["K"] = torch.stack(collated_dict["K"])  # (N, V, 4, 4)
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of calibration matrices instead of a batched tensor."
+            )
+
+    # If collate_batched_meshes receives voxels and all models have the same
+    # number of views V, stack the batches of voxels into a new batch of shape
+    # (N, V, S, S, S), where S is the voxel size.
+    if "voxels" in collated_dict:
+        try:
+            collated_dict["voxels"] = torch.stack(collated_dict["voxels"])
+        except RuntimeError:
+            print(
+                "Models don't have the same number of views. Now returning "
+                "lists of voxels instead of a batched tensor."
+            )
+    return collated_dict
+
+
+def compute_extrinsic_matrix(azimuth, elevation, distance):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96
+
+    Compute 4x4 extrinsic matrix that converts from homogenous world coordinates
+    to homogenous camera coordinates. We assume that the camera is looking at the
+    origin.
+    Used in R2N2 Dataset when computing calibration matrices.
+
+    Args:
+        azimuth: Rotation about the z-axis, in degrees.
+        elevation: Rotation above the xy-plane, in degrees.
+        distance: Distance from the origin.
+
+    Returns:
+        FloatTensor of shape (4, 4).
+    """
+    azimuth, elevation, distance = float(azimuth), float(elevation), float(distance)
+
+    az_rad = -math.pi * azimuth / 180.0
+    el_rad = -math.pi * elevation / 180.0
+    sa = math.sin(az_rad)
+    ca = math.cos(az_rad)
+    se = math.sin(el_rad)
+    ce = math.cos(el_rad)
+    R_world2obj = torch.tensor(
+        [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]]
+    )
+    R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
+    R_world2cam = R_obj2cam.mm(R_world2obj)
+    cam_location = torch.tensor([[distance, 0, 0]]).t()
+    T_world2cam = -(R_obj2cam.mm(cam_location))
+    RT = torch.cat([R_world2cam, T_world2cam], dim=1)
+    RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])])
+
+    # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it
+    # rotates the model 90 degrees about the x axis. To compensate for this quirk we
+    # roll that rotation into the extrinsic matrix here
+    rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+    RT = RT.mm(rot.to(RT))
+
+    return RT
+
+
+def read_binvox_coords(f, integer_division=True, dtype=torch.float32):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L5
+
+    Read a binvox file and return the indices of all nonzero voxels.
+
+    This matches the behavior of binvox_rw.read_as_coord_array
+    (https://github.com/dimatura/binvox-rw-py/blob/public/binvox_rw.py#L153)
+    but this implementation uses torch rather than numpy, and is more efficient
+    due to improved vectorization.
+
+    Georgia: I think that binvox_rw.read_as_coord_array actually has a bug; when converting
+    linear indices into three-dimensional indices, they use floating-point
+    division instead of integer division. We can reproduce their incorrect
+    implementation by passing integer_division=False.
+
+    Args:
+        f (str): A file pointer to the binvox file to read
+        integer_division (bool): If False, then match the buggy implementation from binvox_rw
+        dtype: Datatype of the output tensor. Use float64 to match binvox_rw
+
+    Returns:
+        coords (tensor): A tensor of shape (N, 3) where N is the number of nonzero voxels,
+            and coords[i] = (x, y, z) gives the index of the ith nonzero voxel. If the
+            voxel grid has shape (V, V, V) then we have 0 <= x, y, z < V.
+    """
+    size, translation, scale = _read_binvox_header(f)
+    storage = torch.ByteStorage.from_buffer(f.read())
+    data = torch.tensor([], dtype=torch.uint8)
+    data.set_(source=storage)
+    vals, counts = data[::2], data[1::2]
+    idxs = _compute_idxs(vals, counts)
+    if not integer_division:
+        idxs = idxs.to(dtype)
+    x_idxs = idxs // (size * size)
+    zy_idxs = idxs % (size * size)
+    z_idxs = zy_idxs // size
+    y_idxs = zy_idxs % size
+    coords = torch.stack([x_idxs, y_idxs, z_idxs], dim=1)
+    return coords.to(dtype)
+
+
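A small sketch of using read_binvox_coords on its own (the path follows the ShapeNetVox32 layout used by the dataset; R2N2_PATH, synset_id and model_id are placeholders):

    import os

    binvox_path = os.path.join(R2N2_PATH, "ShapeNetVox32", synset_id, model_id, "model.binvox")
    with open(binvox_path, "rb") as f:
        coords = read_binvox_coords(f)  # (N, 3) indices of occupied voxels
    # For a 32^3 grid, every index satisfies 0 <= x, y, z < 32.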
+def _compute_idxs(vals, counts):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L58
+
+    Fast vectorized version of index computation.
+
+    Args:
+        vals: tensor of binary values indicating voxel presence in a dense format.
+        counts: tensor of number of occurence of each value in vals.
+
+    Returns:
+        idxs: A tensor of shape (N), where N is the number of nonzero voxels.
+    """
+    # Consider an example where:
+    # vals   = [0, 1, 0, 1, 1]
+    # counts = [2, 3, 3, 2, 1]
+    #
+    # These values of counts and vals mean that the dense binary grid is:
+    # [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
+    #
+    # So the nonzero indices we want to return are:
+    # [2, 3, 4, 8, 9, 10]
+
+    # After the cumsum we will have:
+    # end_idxs = [2, 5, 8, 10, 11]
+    end_idxs = counts.cumsum(dim=0)
+
+    # After masking and computing start_idx we have:
+    # end_idxs   = [5, 10, 11]
+    # counts     = [3, 2, 1]
+    # start_idxs = [2, 8, 10]
+    mask = vals == 1
+    end_idxs = end_idxs[mask]
+    counts = counts[mask].to(end_idxs)
+    start_idxs = end_idxs - counts
+
+    # We initialize delta as:
+    # [2, 1, 1, 1, 1, 1]
+    delta = torch.ones(counts.sum().item(), dtype=torch.int64)
+    delta[0] = start_idxs[0]
+
+    # We compute pos = [3, 5], val = [3, 0]; then delta is
+    # [2, 1, 1, 4, 1, 1]
+    pos = counts.cumsum(dim=0)[:-1]
+    val = start_idxs[1:] - end_idxs[:-1]
+    delta[pos] += val
+
+    # A final cumsum gives the idx we want: [2, 3, 4, 8, 9, 10]
+    idxs = delta.cumsum(dim=0)
+    return idxs
+
+
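The worked example in the comments above can be checked directly:

    import torch

    vals = torch.tensor([0, 1, 0, 1, 1], dtype=torch.uint8)
    counts = torch.tensor([2, 3, 3, 2, 1], dtype=torch.int64)
    # Dense grid: [0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1]
    print(_compute_idxs(vals, counts))  # tensor([ 2,  3,  4,  8,  9, 10])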
+def _read_binvox_header(f):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/binvox_torch.py#L99
+
+    Read binvox header and extract information regarding voxel sizes and translations
+    to original voxel coordinates.
+
+    Args:
+        f (str): A file pointer to the binvox file to read.
+
+    Returns:
+        size (int): size of voxel.
+        translation (tuple(float)): translation to original voxel coordinates.
+        scale (float): scale to original voxel coordinates.
+    """
+    # First line of the header should be "#binvox 1"
+    line = f.readline().strip()
+    if line != b"#binvox 1":
+        raise ValueError("Invalid header (line 1)")
+
+    # Second line of the header should be "dim [int] [int] [int]"
+    # and all three int should be the same
+    line = f.readline().strip()
+    if not line.startswith(b"dim "):
+        raise ValueError("Invalid header (line 2)")
+    dims = line.split(b" ")
+    try:
+        dims = [int(d) for d in dims[1:]]
+    except ValueError:
+        raise ValueError("Invalid header (line 2)")
+    if len(dims) != 3 or dims[0] != dims[1] or dims[0] != dims[2]:
+        raise ValueError("Invalid header (line 2)")
+    size = dims[0]
+
+    # Third line of the header should be "translate [float] [float] [float]"
+    line = f.readline().strip()
+    if not line.startswith(b"translate "):
+        raise ValueError("Invalid header (line 3)")
+    translation = line.split(b" ")
+    if len(translation) != 4:
+        raise ValueError("Invalid header (line 3)")
+    try:
+        translation = tuple(float(t) for t in translation[1:])
+    except ValueError:
+        raise ValueError("Invalid header (line 3)")
+
+    # Fourth line of the header should be "scale [float]"
+    line = f.readline().strip()
+    if not line.startswith(b"scale "):
+        raise ValueError("Invalid header (line 4)")
+    line = line.split(b" ")
+    if not len(line) == 2:
+        raise ValueError("Invalid header (line 4)")
+    scale = float(line[1])
+
+    # Fifth line of the header should be "data"
+    line = f.readline().strip()
+    if not line == b"data":
+        raise ValueError("Invalid header (line 5)")
+
+    return size, translation, scale
+
+
+def align_bbox(src, tgt):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/tools/preprocess_shapenet.py#L263
+
+    Return a copy of src points in the coordinate system of tgt by applying a
+    scale and shift along each coordinate axis to make the min / max values align.
+
+    Args:
+        src, tgt: Torch Tensor of shape (N, 3)
+
+    Returns:
+        out: Torch Tensor of shape (N, 3)
+    """
+    if src.ndim != 2 or tgt.ndim != 2:
+        raise ValueError("Both src and tgt need to have dimensions of 2.")
+    if src.shape[-1] != 3 or tgt.shape[-1] != 3:
+        raise ValueError(
+            "Both src and tgt need to have sizes of 3 along the second dimension."
+        )
+    src_min = src.min(dim=0)[0]
+    src_max = src.max(dim=0)[0]
+    tgt_min = tgt.min(dim=0)[0]
+    tgt_max = tgt.max(dim=0)[0]
+    scale = (tgt_max - tgt_min) / (src_max - src_min)
+    shift = tgt_min - scale * src_min
+    out = scale * src + shift
+    return out
+
+
+def voxelize(voxel_coords, P, V):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/tools/preprocess_shapenet.py#L284
+    but changing flip y to flip x.
+
+    Creating voxels of shape (D, D, D) from voxel_coords and projection matrix.
+
+    Args:
+        voxel_coords: FloatTensor of shape (V, 3) giving voxel's coordinates aligned to
+            the vertices.
+        P: FloatTensor of shape (4, 4) giving the projection matrix.
+        V: Voxel size of the output.
+
+    Returns:
+        voxels: Tensor of shape (D, D, D) giving the voxelized result.
+    """
+    device = voxel_coords.device
+    voxel_coords = project_verts(voxel_coords, P)
+
+    # Using the actual zmin and zmax of the model is bad because we need them
+    # to perform the inverse transform, which transform voxels back into world
+    # space for refinement or evaluation. Instead we use an empirical min and
+    # max over the dataset; that way it is consistent for all images.
+    zmin = SHAPENET_MIN_ZMIN
+    zmax = SHAPENET_MAX_ZMAX
+
+    # Once we know zmin and zmax, we need to adjust the z coordinates so the
+    # range [zmin, zmax] instead runs from [-1, 1]
+    m = 2.0 / (zmax - zmin)
+    b = -2.0 * zmin / (zmax - zmin) - 1
+    voxel_coords[:, 2].mul_(m).add_(b)
+    voxel_coords[:, 0].mul_(-1)  # Flip x
+
+    # Now voxels are in [-1, 1]^3; map to [0, V-1)^3
+    voxel_coords = 0.5 * (V - 1) * (voxel_coords + 1.0)
+    voxel_coords = voxel_coords.round().to(torch.int64)
+    valid = (0 <= voxel_coords) * (voxel_coords < V)
+    valid = valid[:, 0] * valid[:, 1] * valid[:, 2]
+    x, y, z = voxel_coords.unbind(dim=1)
+    x, y, z = x[valid], y[valid], z[valid]
+    voxels = torch.zeros(V, V, V, dtype=torch.uint8, device=device)
+    voxels[z, y, x] = 1
+
+    return voxels
+
+
+def project_verts(verts, P, eps=1e-1):
+    """
+    Copied from meshrcnn codebase:
+    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L159
+
+    Project verticies using a 4x4 transformation matrix.
+
+    Args:
+        verts: FloatTensor of shape (N, V, 3) giving a batch of vertex positions or of
+            shape (V, 3) giving a single set of vertex positions.
+        P: FloatTensor of shape (N, 4, 4) giving projection matrices or of shape (4, 4)
+            giving a single projection matrix.
+
+    Returns:
+        verts_out: FloatTensor of shape (N, V, 3) giving vertex positions (x, y, z)
+            where verts_out[i] is the result of transforming verts[i] by P[i].
+    """
+    # Handle unbatched inputs
+    singleton = False
+    if verts.dim() == 2:
+        assert P.dim() == 2
+        singleton = True
+        verts, P = verts[None], P[None]
+
+    N, V = verts.shape[0], verts.shape[1]
+    dtype, device = verts.dtype, verts.device
+
+    # Add an extra row of ones to the world-space coordinates of verts before
+    # multiplying by the projection matrix. We could avoid this allocation by
+    # instead multiplying by a 4x3 submatrix of the projectio matrix, then
+    # adding the remaining 4x1 vector. Not sure whether there will be much
+    # performance difference between the two.
+    ones = torch.ones(N, V, 1, dtype=dtype, device=device)
+    verts_hom = torch.cat([verts, ones], dim=2)
+    verts_cam_hom = torch.bmm(verts_hom, P.transpose(1, 2))
+
+    # Avoid division by zero by clamping the absolute value
+    w = verts_cam_hom[:, :, 3:]
+    w_sign = w.sign()
+    w_sign[w == 0] = 1
+    w = w_sign * w.abs().clamp(min=eps)
+
+    verts_proj = verts_cam_hom[:, :, :3] / w
+
+    if singleton:
+        return verts_proj[0]
+    return verts_proj
+
+
+class BlenderCamera(CamerasBase):
+    """
+    Camera for rendering objects with calibration matrices from the R2N2 dataset
+    (which uses Blender for rendering the views for each model).
+    """
+
+    def __init__(self, R=r, T=t, K=k, device="cpu"):
+        """
+        Args:
+            R: Rotation matrix of shape (N, 3, 3).
+            T: Translation matrix of shape (N, 3).
+            K: Intrinsic matrix of shape (N, 4, 4).
+            device: torch.device or str.
+        """
+        # The initializer formats all inputs to torch tensors and broadcasts
+        # all the inputs to have the same batch dimension where necessary.
+        super().__init__(device=device, R=R, T=T, K=K)
+
+    def get_projection_transform(self, **kwargs) -> Transform3d:
+        transform = Transform3d(device=self.device)
+        transform._matrix = self.K.transpose(1, 2).contiguous()  # pyre-ignore[16]
+        return transform
+
+
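For illustration, a BlenderCamera can be built straight from the calibration matrices returned by the dataset (a sketch; r2n2_dataset is an R2N2 instance as in the earlier sketches):

    model = r2n2_dataset[0]
    cameras = BlenderCamera(R=model["R"], T=model["T"], K=model["K"], device="cpu")
    proj = cameras.get_projection_transform()  # Transform3d wrapping K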
+def render_cubified_voxels(
+    voxels: torch.Tensor, shader_type=HardPhongShader, device="cpu", **kwargs
+):
+    """
+    Use the Cubify operator to convert inputs voxels to a mesh and then render that mesh.
+
+    Args:
+        voxels: FloatTensor of shape (N, D, D, D) where N is the batch size and
+            D is the number of voxels along each dimension.
+        shader_type: Shader to use for rendering. Examples include HardPhongShader
+            (default), SoftPhongShader etc or any other type of valid Shader class.
+        device: torch.device on which the tensors should be located.
+        **kwargs: Accepts any of the kwargs that the renderer supports.
+    Returns:
+        Batch of rendered images of shape (N, H, W, 3).
+    """
+    cubified_voxels = cubify(voxels, CUBIFY_THRESH).to(device)
+    cubified_voxels.textures = TexturesVertex(
+        verts_features=torch.ones_like(cubified_voxels.verts_padded(), device=device)
+    )
+    cameras = BlenderCamera(device=device)
+    renderer = MeshRenderer(
+        rasterizer=MeshRasterizer(
+            cameras=cameras,
+            raster_settings=kwargs.get("raster_settings", RasterizationSettings()),
+        ),
+        shader=shader_type(
+            device=device,
+            cameras=cameras,
+            lights=kwargs.get("lights", PointLights()).to(device),
+        ),
+    )
+    return renderer(cubified_voxels)
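Mirroring the new test_render_voxels added below, a sketch of rendering the voxels of a single model:

    device = torch.device("cuda:0")
    r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True)
    r2n2_model = r2n2_dataset[6, [5]]  # model 6, view 5 only
    vox_render = render_cubified_voxels(r2n2_model["voxels"], device=device)
    vox_render_rgb = vox_render[0, ..., :3].cpu()  # rendered image of the cubified voxels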
@@ -1,8 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-import math
 from typing import Dict, List

-import torch
 from pytorch3d.structures import Meshes

@@ -33,78 +31,4 @@ def collate_batched_meshes(batch: List[Dict]):
         collated_dict["mesh"] = Meshes(
             verts=collated_dict["verts"], faces=collated_dict["faces"]
         )
-
-    # If collate_batched_meshes receives R2N2 items with images and that
-    # all models have the same number of views V, stack the batches of
-    # views of each model into a new batch of shape (N, V, H, W, 3).
-    # Otherwise leave it as a list.
-    if "images" in collated_dict:
-        try:
-            collated_dict["images"] = torch.stack(collated_dict["images"])
-        except RuntimeError:
-            print(
-                "Models don't have the same number of views. Now returning "
-                "lists of images instead of batches."
-            )
-
-    # If collate_batched_meshes receives R2N2 items with camera calibration
-    # matrices and that all models have the same number of views V, stack each
-    # type of matrices into a new batch of shape (N, V, ...).
-    # Otherwise leave them as lists.
-    if all(x in collated_dict for x in ["R", "T", "K"]):
-        try:
-            collated_dict["R"] = torch.stack(collated_dict["R"])  # (N, V, 3, 3)
-            collated_dict["T"] = torch.stack(collated_dict["T"])  # (N, V, 3)
-            collated_dict["K"] = torch.stack(collated_dict["K"])  # (N, V, 4, 4)
-        except RuntimeError:
-            print(
-                "Models don't have the same number of views. Now returning "
-                "lists of calibration matrices instead of batches."
-            )
-
     return collated_dict
-
-
-def compute_extrinsic_matrix(azimuth, elevation, distance):
-    """
-    Copied from meshrcnn codebase:
-    https://github.com/facebookresearch/meshrcnn/blob/master/shapenet/utils/coords.py#L96
-
-    Compute 4x4 extrinsic matrix that converts from homogenous world coordinates
-    to homogenous camera coordinates. We assume that the camera is looking at the
-    origin.
-    Used in R2N2 Dataset when computing calibration matrices.
-
-    Args:
-        azimuth: Rotation about the z-axis, in degrees.
-        elevation: Rotation above the xy-plane, in degrees.
-        distance: Distance from the origin.
-
-    Returns:
-        FloatTensor of shape (4, 4).
-    """
-    azimuth, elevation, distance = float(azimuth), float(elevation), float(distance)
-
-    az_rad = -math.pi * azimuth / 180.0
-    el_rad = -math.pi * elevation / 180.0
-    sa = math.sin(az_rad)
-    ca = math.cos(az_rad)
-    se = math.sin(el_rad)
-    ce = math.cos(el_rad)
-    R_world2obj = torch.tensor(
-        [[ca * ce, sa * ce, -se], [-sa, ca, 0], [ca * se, sa * se, ce]]
-    )
-    R_obj2cam = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
-    R_world2cam = R_obj2cam.mm(R_world2obj)
-    cam_location = torch.tensor([[distance, 0, 0]]).t()
-    T_world2cam = -(R_obj2cam.mm(cam_location))
-    RT = torch.cat([R_world2cam, T_world2cam], dim=1)
-    RT = torch.cat([RT, torch.tensor([[0.0, 0, 0, 1]])])
-
-    # Georgia: For some reason I cannot fathom, when Blender loads a .obj file it
-    # rotates the model 90 degrees about the x axis. To compensate for this quirk we
-    # roll that rotation into the extrinsic matrix here
-    rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
-    RT = RT.mm(rot.to(RT))
-
-    return RT
BIN	tests/data/test_r2n2_voxel_to_mesh_render.png	Normal file (binary file not shown; 16 KiB)
@@ -11,7 +11,12 @@ import numpy as np
 import torch
 from common_testing import TestCaseMixin, load_rgb_image
 from PIL import Image
-from pytorch3d.datasets import R2N2, BlenderCamera, collate_batched_meshes
+from pytorch3d.datasets import (
+    R2N2,
+    BlenderCamera,
+    collate_batched_R2N2,
+    render_cubified_voxels,
+)
 from pytorch3d.renderer import (
     OpenGLPerspectiveCameras,
     PointLights,
@@ -62,8 +67,10 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         Test the loaded train split of R2N2 return items of the correct shapes and types. Also
         check the first image returned is correct.
         """
-        # Load dataset in the test split.
-        r2n2_dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        # Load dataset in the train split.
+        r2n2_dataset = R2N2(
+            "test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+        )

         # Check total number of objects in the dataset is correct.
         with open(SPLITS_PATH) as splits:
@@ -114,6 +121,10 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         self.assertEqual(r2n2_dataset[635]["images"].shape[0], 5)
         self.assertEqual(r2n2_dataset[8369]["images"].shape[0], 10)

+        # Check that the voxel tensor returned by __getitem__ has the correct shape.
+        self.assertEqual(r2n2_obj["voxels"].ndim, 4)
+        self.assertEqual(r2n2_obj["voxels"].shape, (24, 128, 128, 128))
+
     def test_collate_models(self):
         """
         Test collate_batched_meshes returns items of the correct shapes and types.
@@ -121,10 +132,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         the correct shapes and types are returned.
         """
         # Load dataset in the train split.
-        r2n2_dataset = R2N2("val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
+        r2n2_dataset = R2N2(
+            "val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+        )

         # Randomly retrieve several objects from the dataset and collate them.
-        collated_meshes = collate_batched_meshes(
+        collated_meshes = collate_batched_R2N2(
             [r2n2_dataset[idx] for idx in torch.randint(len(r2n2_dataset), (6,))]
         )
         # Check the collated verts and faces have the correct shapes.
@@ -145,7 +158,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         # in batch have the correct shape.
         batch_size = 12
         r2n2_loader = DataLoader(
-            r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
+            r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_R2N2
         )
         it = iter(r2n2_loader)
         object_batch = next(it)
@@ -158,6 +171,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         self.assertEqual(object_batch["R"].shape[0], batch_size)
         self.assertEqual(object_batch["T"].shape[0], batch_size)
         self.assertEqual(object_batch["K"].shape[0], batch_size)
+        self.assertEqual(len(object_batch["voxels"]), batch_size)

     def test_catch_render_arg_errors(self):
         """
@@ -338,3 +352,24 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
                 "test_r2n2_render_with_blender_calibrations_%s.png" % idx, DATA_DIR
             )
             self.assertClose(r2n2_batch_rgb, image_ref, atol=0.05)
+
+    def test_render_voxels(self):
+        """
+        Test rendering meshes formed from voxels.
+        """
+        # Set up device and seed for random selections.
+        device = torch.device("cuda:0")
+
+        # Load dataset in the train split with only a single view returned for each model.
+        r2n2_dataset = R2N2(
+            "train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+        )
+        r2n2_model = r2n2_dataset[6, [5]]
+        vox_render = render_cubified_voxels(r2n2_model["voxels"], device=device)
+        vox_render_rgb = vox_render[0, ..., :3].squeeze().cpu()
+        if DEBUG:
+            Image.fromarray((vox_render_rgb.numpy() * 255).astype(np.uint8)).save(
+                DATA_DIR / ("DEBUG_r2n2_voxel_to_mesh_render.png")
+            )
+        image_ref = load_rgb_image("test_r2n2_voxel_to_mesh_render.png", DATA_DIR)
+        self.assertClose(vox_render_rgb, image_ref, atol=0.05)