Texture loading and rendering in ShapeNetCore and R2N2 data loaders

Summary:
- Add support for loading textures from ShapeNet Obj files as a texture atlas.
- Support textured rendering of shapenet models

Reviewed By: gkioxari

Differential Revision: D23141143

fbshipit-source-id: 26eb81758d4cdbd6d820b072b58f5c6c08cb90bc
This commit is contained in:
Nikhila Ravi 2020-08-21 20:41:07 -07:00 committed by Facebook GitHub Bot
parent 90f6a005b0
commit 778383eef7
7 changed files with 134 additions and 48 deletions

View File

@ -10,7 +10,6 @@ import numpy as np
import torch
from PIL import Image
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj
from pytorch3d.renderer import HardPhongShader
from tabulate import tabulate
@ -45,6 +44,7 @@ class R2N2(ShapeNetBase):
dataset. The R2N2 dataset also contains its own 24 renderings of each object and
voxelized models. Most of the models have all 24 views in the same split, but there
are eight of them that divide their views between train and test splits.
"""
def __init__(
@ -55,6 +55,10 @@ class R2N2(ShapeNetBase):
splits_file,
return_all_views: bool = True,
return_voxels: bool = False,
views_rel_path: str = "ShapeNetRendering",
voxels_rel_path: str = "ShapeNetVoxels",
load_textures: bool = True,
texture_resolution: int = 4,
):
"""
Store each object's synset id and models id from the given directories.
@ -69,10 +73,24 @@ class R2N2(ShapeNetBase):
selected and loaded.
return_voxels(bool): Indicator of whether or not to return voxels as a tensor
of shape (D, D, D) where D is the number of voxels along each dimension.
views_rel_path: path to rendered views within the r2n2_dir. If not specified,
the renderings are assumed to be at os.path.join(r2n2_dir, "ShapeNetRendering").
voxels_rel_path: path to voxels within the r2n2_dir. If not specified,
the voxels are assumed to be at os.path.join(r2n2_dir, "ShapeNetVoxels").
load_textures: Boolean indicating whether textures should be loaded for the model.
Textures will be of type TexturesAtlas i.e. a texture map per face.
texture_resolution: Int specifying the resolution of the texture map per face
created using the textures in the obj file. A
(texture_resolution, texture_resolution, 3) map is created per face.
"""
super().__init__()
self.shapenet_dir = shapenet_dir
self.r2n2_dir = r2n2_dir
self.views_rel_path = views_rel_path
self.voxels_rel_path = voxels_rel_path
self.load_textures = load_textures
self.texture_resolution = texture_resolution
# Examine if split is valid.
if split not in ["train", "val", "test"]:
raise ValueError("split has to be one of (train, val, test).")
@ -90,22 +108,22 @@ class R2N2(ShapeNetBase):
self.return_images = True
# Check if the folder containing R2N2 renderings is included in r2n2_dir.
if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")):
if not path.isdir(path.join(r2n2_dir, views_rel_path)):
self.return_images = False
msg = (
"ShapeNetRendering not found in %s. R2N2 renderings will "
"%s not found in %s. R2N2 renderings will "
"be skipped when returning models."
) % (r2n2_dir)
) % (views_rel_path, r2n2_dir)
warnings.warn(msg)
self.return_voxels = return_voxels
# Check if the folder containing voxel coordinates is included in r2n2_dir.
if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")):
if not path.isdir(path.join(r2n2_dir, voxels_rel_path)):
self.return_voxels = False
msg = (
"ShapeNetVox32 not found in %s. Voxel coordinates will "
"%s not found in %s. Voxel coordinates will "
"be skipped when returning models."
) % (r2n2_dir)
) % (voxels_rel_path, r2n2_dir)
warnings.warn(msg)
synset_set = set()
@ -230,8 +248,11 @@ class R2N2(ShapeNetBase):
model_path = path.join(
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx
verts, faces, textures = self._load_mesh(model_path)
model["verts"] = verts
model["faces"] = faces
model["textures"] = textures
model["label"] = self.synset_dict[model["synset_id"]]
model["images"] = None
@ -240,7 +261,7 @@ class R2N2(ShapeNetBase):
if self.return_images:
rendering_path = path.join(
self.r2n2_dir,
"ShapeNetRendering",
self.views_rel_path,
model["synset_id"],
model["model_id"],
"rendering",
@ -284,10 +305,11 @@ class R2N2(ShapeNetBase):
model["K"] = K.expand(len(model_views), 4, 4)
voxels_list = []
# Read voxels if required.
voxel_path = path.join(
self.r2n2_dir,
"ShapeNetVox32",
self.voxels_rel_path,
model["synset_id"],
model["model_id"],
"model.binvox",

View File

@ -8,7 +8,6 @@ from pathlib import Path
from typing import Dict
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj
SYNSET_DICT_DIR = Path(__file__).resolve().parent
@ -21,7 +20,14 @@ class ShapeNetCore(ShapeNetBase):
https://www.shapenet.org/.
"""
def __init__(self, data_dir, synsets=None, version: int = 1):
def __init__(
self,
data_dir,
synsets=None,
version: int = 1,
load_textures: bool = True,
texture_resolution: int = 4,
):
"""
Store each object's synset id and models id from data_dir.
@ -38,10 +44,17 @@ class ShapeNetCore(ShapeNetBase):
respectively. You can combine the categories manually if needed.
Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to
version 1.
load_textures: Boolean indicating whether textures should be loaded for the model.
Textures will be of type TexturesAtlas i.e. a texture map per face.
texture_resolution: Int specifying the resolution of the texture map per face
created using the textures in the obj file. A
(texture_resolution, texture_resolution, 3) map is created per face.
"""
super().__init__()
self.shapenet_dir = data_dir
self.load_textures = load_textures
self.texture_resolution = texture_resolution
if version not in [1, 2]:
raise ValueError("Version number must be either 1 or 2.")
self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
@ -133,7 +146,9 @@ class ShapeNetCore(ShapeNetBase):
model_path = path.join(
self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx
verts, faces, textures = self._load_mesh(model_path)
model["verts"] = verts
model["faces"] = faces
model["textures"] = textures
model["label"] = self.synset_dict[model["synset_id"]]
return model

View File

@ -1,11 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from os import path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import torch
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.io import load_obj
from pytorch3d.renderer import (
FoVPerspectiveCameras,
HardPhongShader,
@ -16,6 +15,8 @@ from pytorch3d.renderer import (
TexturesVertex,
)
from .utils import collate_batched_meshes
class ShapeNetBase(torch.utils.data.Dataset):
"""
@ -35,6 +36,8 @@ class ShapeNetBase(torch.utils.data.Dataset):
self.synset_num_models = {}
self.shapenet_dir = ""
self.model_dir = "model.obj"
self.load_textures = True
self.texture_resolution = 4
def __len__(self):
"""
@ -74,6 +77,27 @@ class ShapeNetBase(torch.utils.data.Dataset):
model["model_id"] = self.model_ids[idx]
return model
def _load_mesh(self, model_path) -> Tuple:
    """
    Load the vertices, faces and (optionally) a per-face texture atlas from
    the obj file at model_path.

    Args:
        model_path: path to the obj file to load.

    Returns:
        3-tuple of (verts, faces, textures):
        - verts: the vertices returned by load_obj.
        - faces: the verts_idx field of the faces returned by load_obj.
        - textures: None if self.load_textures is False. Otherwise the
          texture atlas returned by load_obj or, if the obj file yields no
          atlas, an all-ones (white) tensor of shape
          (F, texture_resolution, texture_resolution, 3) where F is the
          number of faces.
    """
    verts, faces, aux = load_obj(
        model_path,
        create_texture_atlas=self.load_textures,
        load_textures=self.load_textures,
        texture_atlas_size=self.texture_resolution,
    )

    # Bind textures before the branch so that the return statement below is
    # valid when texture loading is disabled; previously this path raised
    # UnboundLocalError.
    textures = None
    if self.load_textures:
        textures = aux.texture_atlas
        # Some meshes don't have textures. In this case
        # create a white texture map
        if textures is None:
            textures = verts.new_ones(
                faces.verts_idx.shape[0],
                self.texture_resolution,
                self.texture_resolution,
                3,
            )

    return verts, faces.verts_idx, textures
def render(
self,
model_ids: Optional[List[str]] = None,
@ -112,19 +136,15 @@ class ShapeNetBase(torch.utils.data.Dataset):
Batch of rendered images of shape (N, H, W, 3).
"""
idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
paths = [
path.join(
self.shapenet_dir,
self.synset_ids[idx],
self.model_ids[idx],
self.model_dir,
# Use the getitem method which loads mesh + texture
models = [self[idx] for idx in idxs]
meshes = collate_batched_meshes(models)["mesh"]
if meshes.textures is None:
meshes.textures = TexturesVertex(
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
)
for idx in idxs
]
meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
meshes.textures = TexturesVertex(
verts_features=torch.ones_like(meshes.verts_padded(), device=device)
)
meshes = meshes.to(device)
cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device)
if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
raise ValueError("Mismatch between batch dims of cameras and meshes.")

View File

@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Dict, List
from pytorch3d.renderer.mesh import TexturesAtlas
from pytorch3d.structures import Meshes
@ -28,7 +29,15 @@ def collate_batched_meshes(batch: List[Dict]):
collated_dict["mesh"] = None
if {"verts", "faces"}.issubset(collated_dict.keys()):
textures = None
if "textures" in collated_dict:
textures = TexturesAtlas(atlas=collated_dict["textures"])
collated_dict["mesh"] = Meshes(
verts=collated_dict["verts"], faces=collated_dict["faces"]
verts=collated_dict["verts"],
faces=collated_dict["faces"],
textures=textures,
)
return collated_dict

View File

@ -253,7 +253,7 @@ def load_objs_as_meshes(
tex = None
if create_texture_atlas:
# TexturesAtlas type
tex = TexturesAtlas(atlas=[aux.texture_atlas])
tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
else:
# TexturesUV type
tex_maps = aux.texture_images
@ -477,18 +477,20 @@ def _load_obj(
face_material_names = np.array(material_names)[idx] # (F,)
face_material_names[idx == -1] = ""
# Get the uv coords for each vert in each face
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)
texture_atlas = None
if len(verts_uvs) > 0:
# Get the uv coords for each vert in each face
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)
# Construct the atlas.
texture_atlas = make_mesh_texture_atlas(
material_colors,
texture_images,
face_material_names,
faces_verts_uvs,
texture_atlas_size,
texture_wrap,
)
# Construct the atlas.
texture_atlas = make_mesh_texture_atlas(
material_colors,
texture_images,
face_material_names,
faces_verts_uvs,
texture_atlas_size,
texture_wrap,
)
else:
warnings.warn(f"Mtl file does not exist: {f_mtl}")
elif len(material_names) > 0:

View File

@ -33,6 +33,8 @@ from torch.utils.data import DataLoader
R2N2_PATH = None
SHAPENET_PATH = None
SPLITS_PATH = None
VOXELS_REL_PATH = "ShapeNetVox"
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
@ -69,7 +71,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
"""
# Load dataset in the test split.
r2n2_dataset = R2N2(
"test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
"test",
SHAPENET_PATH,
R2N2_PATH,
SPLITS_PATH,
return_voxels=True,
voxels_rel_path=VOXELS_REL_PATH,
)
# Check total number of objects in the dataset is correct.
@ -133,7 +140,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
"""
# Load dataset in the val split.
r2n2_dataset = R2N2(
"val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
"val",
SHAPENET_PATH,
R2N2_PATH,
SPLITS_PATH,
return_voxels=True,
voxels_rel_path=VOXELS_REL_PATH,
)
# Randomly retrieve several objects from the dataset and collate them.
@ -362,7 +374,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
# Load dataset in the train split with only a single view returned for each model.
r2n2_dataset = R2N2(
"train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
"train",
SHAPENET_PATH,
R2N2_PATH,
SPLITS_PATH,
return_voxels=True,
voxels_rel_path=VOXELS_REL_PATH,
)
r2n2_model = r2n2_dataset[6, [5]]
vox_render = render_cubified_voxels(r2n2_model["voxels"], device=device)

View File

@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
# Set the SHAPENET_PATH to the local path to the dataset
SHAPENET_PATH = None
VERSION = 1
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
@ -55,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
# Load ShapeNetCore without specifying any particular categories.
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
shapenet_dataset = ShapeNetCore(SHAPENET_PATH, version=VERSION)
# Count the number of grandchildren directories (which should be equal to
# the total number of objects in the dataset) by walking through the given