From 778383eef77a23686f3d0e68834b29d6d73f8501 Mon Sep 17 00:00:00 2001
From: Nikhila Ravi
Date: Fri, 21 Aug 2020 20:41:07 -0700
Subject: [PATCH] Texture loading and rendering in ShapeNetCore and R2N2 data loaders

Summary:
- Add support for loading textures from ShapeNet Obj files as a texture atlas.
- Support textured rendering of ShapeNet models.

Reviewed By: gkioxari

Differential Revision: D23141143

fbshipit-source-id: 26eb81758d4cdbd6d820b072b58f5c6c08cb90bc
---
 pytorch3d/datasets/r2n2/r2n2.py              | 44 ++++++++++++-----
 pytorch3d/datasets/shapenet/shapenet_core.py | 25 ++++++++--
 pytorch3d/datasets/shapenet_base.py          | 50 ++++++++++++++------
 pytorch3d/datasets/utils.py                  | 11 ++++-
 pytorch3d/io/obj_io.py                       | 26 +++++-----
 tests/test_r2n2.py                           | 23 +++++++--
 tests/test_shapenet_core.py                  |  3 +-
 7 files changed, 134 insertions(+), 48 deletions(-)

diff --git a/pytorch3d/datasets/r2n2/r2n2.py b/pytorch3d/datasets/r2n2/r2n2.py
index 3abe4927..b4ef58c0 100644
--- a/pytorch3d/datasets/r2n2/r2n2.py
+++ b/pytorch3d/datasets/r2n2/r2n2.py
@@ -10,7 +10,6 @@ import numpy as np
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
-from pytorch3d.io import load_obj
 from pytorch3d.renderer import HardPhongShader
 from tabulate import tabulate
 
@@ -45,6 +44,7 @@ class R2N2(ShapeNetBase):
     dataset. The R2N2 dataset also contains its own 24 renderings of each object and
     voxelized models. Most of the models have all 24 views in the same split, but there
     are eight of them that divide their views between train and test splits.
+
     """
 
     def __init__(
@@ -55,6 +55,10 @@
         splits_file,
         return_all_views: bool = True,
         return_voxels: bool = False,
+        views_rel_path: str = "ShapeNetRendering",
+        voxels_rel_path: str = "ShapeNetVoxels",
+        load_textures: bool = True,
+        texture_resolution: int = 4,
     ):
         """
         Store each object's synset id and models id the given directories.
@@ -69,10 +73,24 @@
                 selected and loaded.
            return_voxels(bool): Indicator of whether or not to return voxels as a tensor
                of shape (D, D, D) where D is the number of voxels along each dimension.
+            views_rel_path: path to rendered views within the r2n2_dir. If not specified,
+                the renderings are assumed to be at os.path.join(r2n2_dir, "ShapeNetRendering").
+            voxels_rel_path: path to voxels within the r2n2_dir. If not specified,
+                the voxels are assumed to be at os.path.join(r2n2_dir, "ShapeNetVoxels").
+            load_textures: Boolean indicating whether textures should be loaded for the model.
+                Textures will be of type TexturesAtlas i.e. a texture map per face.
+            texture_resolution: Int specifying the resolution of the texture map per face
+                created using the textures in the obj file. A
+                (texture_resolution, texture_resolution, 3) map is created per face.
+
         """
         super().__init__()
         self.shapenet_dir = shapenet_dir
         self.r2n2_dir = r2n2_dir
+        self.views_rel_path = views_rel_path
+        self.voxels_rel_path = voxels_rel_path
+        self.load_textures = load_textures
+        self.texture_resolution = texture_resolution
         # Examine if split is valid.
         if split not in ["train", "val", "test"]:
             raise ValueError("split has to be one of (train, val, test).")
@@ -90,22 +108,22 @@
         self.return_images = True
 
         # Check if the folder containing R2N2 renderings is included in r2n2_dir.
-        if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")):
+        if not path.isdir(path.join(r2n2_dir, views_rel_path)):
             self.return_images = False
             msg = (
-                "ShapeNetRendering not found in %s. R2N2 renderings will "
-                "be skipped when returning models."
-            ) % (r2n2_dir)
+                "%s not found in %s. R2N2 renderings will "
+                "be skipped when returning models."
+            ) % (views_rel_path, r2n2_dir)
             warnings.warn(msg)
 
         self.return_voxels = return_voxels
         # Check if the folder containing voxel coordinates is included in r2n2_dir.
-        if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")):
+        if not path.isdir(path.join(r2n2_dir, voxels_rel_path)):
             self.return_voxels = False
             msg = (
-                "ShapeNetVox32 not found in %s. Voxel coordinates will "
-                "be skipped when returning models."
-            ) % (r2n2_dir)
+                "%s not found in %s. Voxel coordinates will "
+                "be skipped when returning models."
+            ) % (voxels_rel_path, r2n2_dir)
             warnings.warn(msg)
 
         synset_set = set()
@@ -230,8 +248,11 @@
         model_path = path.join(
             self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
         )
-        model["verts"], faces, _ = load_obj(model_path)
-        model["faces"] = faces.verts_idx
+
+        verts, faces, textures = self._load_mesh(model_path)
+        model["verts"] = verts
+        model["faces"] = faces
+        model["textures"] = textures
         model["label"] = self.synset_dict[model["synset_id"]]
         model["images"] = None
 
@@ -240,7 +261,7 @@
         if self.return_images:
             rendering_path = path.join(
                 self.r2n2_dir,
-                "ShapeNetRendering",
+                self.views_rel_path,
                 model["synset_id"],
                 model["model_id"],
                 "rendering",
@@ -284,10 +305,11 @@
             model["K"] = K.expand(len(model_views), 4, 4)
 
         voxels_list = []
 
+        # Read voxels if required.
         voxel_path = path.join(
             self.r2n2_dir,
-            "ShapeNetVox32",
+            self.voxels_rel_path,
             model["synset_id"],
             model["model_id"],
             "model.binvox",
diff --git a/pytorch3d/datasets/shapenet/shapenet_core.py b/pytorch3d/datasets/shapenet/shapenet_core.py
index 0d799ca2..78184be8 100644
--- a/pytorch3d/datasets/shapenet/shapenet_core.py
+++ b/pytorch3d/datasets/shapenet/shapenet_core.py
@@ -8,7 +8,6 @@ from pathlib import Path
 from typing import Dict
 
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
-from pytorch3d.io import load_obj
 
 
 SYNSET_DICT_DIR = Path(__file__).resolve().parent
@@ -21,7 +20,14 @@ class ShapeNetCore(ShapeNetBase):
     https://www.shapenet.org/.
     """
 
-    def __init__(self, data_dir, synsets=None, version: int = 1):
+    def __init__(
+        self,
+        data_dir,
+        synsets=None,
+        version: int = 1,
+        load_textures: bool = True,
+        texture_resolution: int = 4,
+    ):
         """
         Store each object's synset id and models id from data_dir.
 
@@ -38,10 +44,17 @@
             respectively. You can combine the categories manually if needed.
             Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to
             version 1.
-
+            load_textures: Boolean indicating whether textures should be loaded for the model.
+                Textures will be of type TexturesAtlas i.e. a texture map per face.
+            texture_resolution: Int specifying the resolution of the texture map per face
+                created using the textures in the obj file. A
+                (texture_resolution, texture_resolution, 3) map is created per face.
""" super().__init__() self.shapenet_dir = data_dir + self.load_textures = load_textures + self.texture_resolution = texture_resolution + if version not in [1, 2]: raise ValueError("Version number must be either 1 or 2.") self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj" @@ -133,7 +146,9 @@ class ShapeNetCore(ShapeNetBase): model_path = path.join( self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir ) - model["verts"], faces, _ = load_obj(model_path) - model["faces"] = faces.verts_idx + verts, faces, textures = self._load_mesh(model_path) + model["verts"] = verts + model["faces"] = faces + model["textures"] = textures model["label"] = self.synset_dict[model["synset_id"]] return model diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py index 2de233d0..32d5c33b 100644 --- a/pytorch3d/datasets/shapenet_base.py +++ b/pytorch3d/datasets/shapenet_base.py @@ -1,11 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import warnings -from os import path -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch -from pytorch3d.io import load_objs_as_meshes +from pytorch3d.io import load_obj from pytorch3d.renderer import ( FoVPerspectiveCameras, HardPhongShader, @@ -16,6 +15,8 @@ from pytorch3d.renderer import ( TexturesVertex, ) +from .utils import collate_batched_meshes + class ShapeNetBase(torch.utils.data.Dataset): """ @@ -35,6 +36,8 @@ class ShapeNetBase(torch.utils.data.Dataset): self.synset_num_models = {} self.shapenet_dir = "" self.model_dir = "model.obj" + self.load_textures = True + self.texture_resolution = 4 def __len__(self): """ @@ -74,6 +77,27 @@ class ShapeNetBase(torch.utils.data.Dataset): model["model_id"] = self.model_ids[idx] return model + def _load_mesh(self, model_path) -> Tuple: + verts, faces, aux = load_obj( + model_path, + create_texture_atlas=self.load_textures, + load_textures=self.load_textures, + texture_atlas_size=self.texture_resolution, + ) + if self.load_textures: + textures = aux.texture_atlas + # Some meshes don't have textures. In this case + # create a white texture map + if textures is None: + textures = verts.new_ones( + faces.verts_idx.shape[0], + self.texture_resolution, + self.texture_resolution, + 3, + ) + + return verts, faces.verts_idx, textures + def render( self, model_ids: Optional[List[str]] = None, @@ -112,19 +136,15 @@ class ShapeNetBase(torch.utils.data.Dataset): Batch of rendered images of shape (N, H, W, 3). 
""" idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs) - paths = [ - path.join( - self.shapenet_dir, - self.synset_ids[idx], - self.model_ids[idx], - self.model_dir, + # Use the getitem method which loads mesh + texture + models = [self[idx] for idx in idxs] + meshes = collate_batched_meshes(models)["mesh"] + if meshes.textures is None: + meshes.textures = TexturesVertex( + verts_features=torch.ones_like(meshes.verts_padded(), device=device) ) - for idx in idxs - ] - meshes = load_objs_as_meshes(paths, device=device, load_textures=False) - meshes.textures = TexturesVertex( - verts_features=torch.ones_like(meshes.verts_padded(), device=device) - ) + + meshes = meshes.to(device) cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device) if len(cameras) != 1 and len(cameras) % len(meshes) != 0: raise ValueError("Mismatch between batch dims of cameras and meshes.") diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py index 5d8dd2ae..029435c1 100644 --- a/pytorch3d/datasets/utils.py +++ b/pytorch3d/datasets/utils.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from typing import Dict, List +from pytorch3d.renderer.mesh import TexturesAtlas from pytorch3d.structures import Meshes @@ -28,7 +29,15 @@ def collate_batched_meshes(batch: List[Dict]): collated_dict["mesh"] = None if {"verts", "faces"}.issubset(collated_dict.keys()): + + textures = None + if "textures" in collated_dict: + textures = TexturesAtlas(atlas=collated_dict["textures"]) + collated_dict["mesh"] = Meshes( - verts=collated_dict["verts"], faces=collated_dict["faces"] + verts=collated_dict["verts"], + faces=collated_dict["faces"], + textures=textures, ) + return collated_dict diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index a0e7848a..e87255c3 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -253,7 +253,7 @@ def load_objs_as_meshes( tex = None if create_texture_atlas: # TexturesAtlas type - tex = TexturesAtlas(atlas=[aux.texture_atlas]) + tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)]) else: # TexturesUV type tex_maps = aux.texture_images @@ -477,18 +477,20 @@ def _load_obj( face_material_names = np.array(material_names)[idx] # (F,) face_material_names[idx == -1] = "" - # Get the uv coords for each vert in each face - faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2) + texture_atlas = None + if len(verts_uvs) > 0: + # Get the uv coords for each vert in each face + faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2) - # Construct the atlas. - texture_atlas = make_mesh_texture_atlas( - material_colors, - texture_images, - face_material_names, - faces_verts_uvs, - texture_atlas_size, - texture_wrap, - ) + # Construct the atlas. + texture_atlas = make_mesh_texture_atlas( + material_colors, + texture_images, + face_material_names, + faces_verts_uvs, + texture_atlas_size, + texture_wrap, + ) else: warnings.warn(f"Mtl file does not exist: {f_mtl}") elif len(material_names) > 0: diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py index 921f9b11..f1ed219e 100644 --- a/tests/test_r2n2.py +++ b/tests/test_r2n2.py @@ -33,6 +33,8 @@ from torch.utils.data import DataLoader R2N2_PATH = None SHAPENET_PATH = None SPLITS_PATH = None +VOXELS_REL_PATH = "ShapeNetVox" + DEBUG = False DATA_DIR = Path(__file__).resolve().parent / "data" @@ -69,7 +71,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase): """ # Load dataset in the train split. 
         r2n2_dataset = R2N2(
-            "test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+            "test",
+            SHAPENET_PATH,
+            R2N2_PATH,
+            SPLITS_PATH,
+            return_voxels=True,
+            voxels_rel_path=VOXELS_REL_PATH,
         )
 
         # Check total number of objects in the dataset is correct.
@@ -133,7 +140,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         """
         # Load dataset in the train split.
         r2n2_dataset = R2N2(
-            "val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+            "val",
+            SHAPENET_PATH,
+            R2N2_PATH,
+            SPLITS_PATH,
+            return_voxels=True,
+            voxels_rel_path=VOXELS_REL_PATH,
         )
 
         # Randomly retrieve several objects from the dataset and collate them.
@@ -362,7 +374,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
 
         # Load dataset in the train split with only a single view returned for each model.
         r2n2_dataset = R2N2(
-            "train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+            "train",
+            SHAPENET_PATH,
+            R2N2_PATH,
+            SPLITS_PATH,
+            return_voxels=True,
+            voxels_rel_path=VOXELS_REL_PATH,
         )
         r2n2_model = r2n2_dataset[6, [5]]
         vox_render = render_cubified_voxels(r2n2_model["voxels"], device=device)
diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py
index b974a700..d832fce3 100644
--- a/tests/test_shapenet_core.py
+++ b/tests/test_shapenet_core.py
@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 
 # Set the SHAPENET_PATH to the local path to the dataset
 SHAPENET_PATH = None
+VERSION = 1
 # If DEBUG=True, save out images generated in the tests for debugging.
 # All saved images have prefix DEBUG_
 DEBUG = False
@@ -55,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
         self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
 
         # Load ShapeNetCore without specifying any particular categories.
-        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
+        shapenet_dataset = ShapeNetCore(SHAPENET_PATH, version=VERSION)
 
         # Count the number of grandchildren directories (which should be equal to
         # the total number of objects in the dataset) by walking through the given
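
A minimal usage sketch of the texture-loading path added by this patch, assuming a local extract of ShapeNetCore v1 whose location is supplied through the placeholder SHAPENET_PATH. The dataset arguments, collate_batched_meshes, and the render() call mirror the diff above; the camera, light, and rasterization settings are illustrative choices rather than values prescribed by the patch.

import torch
from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes
from pytorch3d.renderer import FoVPerspectiveCameras, PointLights, RasterizationSettings

# Placeholder: point this at a local ShapeNetCore v1 extract (assumption, not part of the patch).
SHAPENET_PATH = "/path/to/ShapeNetCore.v1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The two new constructor arguments from this diff enable per-face texture atlases.
shapenet_dataset = ShapeNetCore(
    SHAPENET_PATH, version=1, load_textures=True, texture_resolution=4
)

# Each item now contains verts, faces and a (F, R, R, 3) texture atlas tensor.
model = shapenet_dataset[0]
print(model["label"], model["verts"].shape, model["textures"].shape)

# collate_batched_meshes wraps the per-model atlases into a TexturesAtlas on the batched Meshes.
batch = [shapenet_dataset[idx] for idx in range(3)]
meshes = collate_batched_meshes(batch)["mesh"]
print(meshes.textures)

# Textured rendering through the ShapeNetBase.render helper.
images = shapenet_dataset.render(
    idxs=list(range(3)),
    device=device,
    cameras=FoVPerspectiveCameras(device=device),
    raster_settings=RasterizationSettings(image_size=256),
    lights=PointLights(device=device),
)
print(images.shape)  # batch of rendered images, (N, H, W, 3) per the render() docstring

With load_textures=True the collated Meshes already carries a TexturesAtlas, so the all-white TexturesVertex fallback in render() is only used for models whose textures are unavailable.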