Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
Texture loading and rendering in ShapeNetCore and R2N2 data loaders
Summary:
- Add support for loading textures from ShapeNet obj files as a texture atlas.
- Support textured rendering of ShapeNet models.

Reviewed By: gkioxari

Differential Revision: D23141143

fbshipit-source-id: 26eb81758d4cdbd6d820b072b58f5c6c08cb90bc
This commit is contained in:
parent
90f6a005b0
commit
778383eef7
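
For orientation, a minimal usage sketch of the feature this commit adds. This is not part of the commit: the dataset path is a hypothetical placeholder, and the calls assume the post-commit API shown in the hunks below.

    import torch
    from pytorch3d.datasets import ShapeNetCore
    from pytorch3d.renderer import FoVPerspectiveCameras

    SHAPENET_PATH = "/path/to/ShapeNetCore.v1"  # hypothetical local path
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load_textures and texture_resolution are the new constructor arguments;
    # each face gets a (texture_resolution, texture_resolution, 3) atlas entry.
    dataset = ShapeNetCore(
        SHAPENET_PATH, version=1, load_textures=True, texture_resolution=4
    )
    model = dataset[0]
    print(model["textures"].shape)  # (F, 4, 4, 3) per-face texture atlas

    # Textured rendering through the base-class render() method.
    images = dataset.render(
        idxs=[0], device=device, cameras=FoVPerspectiveCameras(device=device)
    )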
pytorch3d/datasets/r2n2/r2n2.py

@@ -10,7 +10,6 @@ import numpy as np
 import torch
 from PIL import Image
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
-from pytorch3d.io import load_obj
 from pytorch3d.renderer import HardPhongShader
 from tabulate import tabulate
 
@@ -45,6 +44,7 @@ class R2N2(ShapeNetBase):
     dataset. The R2N2 dataset also contains its own 24 renderings of each object and
     voxelized models. Most of the models have all 24 views in the same split, but there
     are eight of them that divide their views between train and test splits.
+
     """
 
     def __init__(
@@ -55,6 +55,10 @@ class R2N2(ShapeNetBase):
         splits_file,
         return_all_views: bool = True,
         return_voxels: bool = False,
+        views_rel_path: str = "ShapeNetRendering",
+        voxels_rel_path: str = "ShapeNetVoxels",
+        load_textures: bool = True,
+        texture_resolution: int = 4,
     ):
         """
         Store each object's synset id and model id in the given directories.
@@ -69,10 +73,24 @@ class R2N2(ShapeNetBase):
                 selected and loaded.
             return_voxels(bool): Indicator of whether or not to return voxels as a tensor
                 of shape (D, D, D) where D is the number of voxels along each dimension.
+            views_rel_path: path to rendered views within the r2n2_dir. If not specified,
+                the renderings are assumed to be at os.path.join(r2n2_dir, "ShapeNetRendering").
+            voxels_rel_path: path to voxels within the r2n2_dir. If not specified,
+                the voxels are assumed to be at os.path.join(r2n2_dir, "ShapeNetVoxels").
+            load_textures: Boolean indicating whether textures should be loaded for the model.
+                Textures will be of type TexturesAtlas i.e. a texture map per face.
+            texture_resolution: Int specifying the resolution of the texture map per face
+                created using the textures in the obj file. A
+                (texture_resolution, texture_resolution, 3) map is created per face.
+
         """
         super().__init__()
         self.shapenet_dir = shapenet_dir
         self.r2n2_dir = r2n2_dir
+        self.views_rel_path = views_rel_path
+        self.voxels_rel_path = voxels_rel_path
+        self.load_textures = load_textures
+        self.texture_resolution = texture_resolution
         # Examine if split is valid.
         if split not in ["train", "val", "test"]:
             raise ValueError("split has to be one of (train, val, test).")
@@ -90,22 +108,22 @@ class R2N2(ShapeNetBase):
 
         self.return_images = True
         # Check if the folder containing R2N2 renderings is included in r2n2_dir.
-        if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")):
+        if not path.isdir(path.join(r2n2_dir, views_rel_path)):
             self.return_images = False
             msg = (
-                "ShapeNetRendering not found in %s. R2N2 renderings will "
+                "%s not found in %s. R2N2 renderings will "
                 "be skipped when returning models."
-            ) % (r2n2_dir)
+            ) % (views_rel_path, r2n2_dir)
             warnings.warn(msg)
 
         self.return_voxels = return_voxels
         # Check if the folder containing voxel coordinates is included in r2n2_dir.
-        if not path.isdir(path.join(r2n2_dir, "ShapeNetVox32")):
+        if not path.isdir(path.join(r2n2_dir, voxels_rel_path)):
             self.return_voxels = False
             msg = (
-                "ShapeNetVox32 not found in %s. Voxel coordinates will "
+                "%s not found in %s. Voxel coordinates will "
                 "be skipped when returning models."
-            ) % (r2n2_dir)
+            ) % (voxels_rel_path, r2n2_dir)
             warnings.warn(msg)
 
         synset_set = set()
@@ -230,8 +248,11 @@ class R2N2(ShapeNetBase):
         model_path = path.join(
             self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
         )
-        model["verts"], faces, _ = load_obj(model_path)
-        model["faces"] = faces.verts_idx
+
+        verts, faces, textures = self._load_mesh(model_path)
+        model["verts"] = verts
+        model["faces"] = faces
+        model["textures"] = textures
         model["label"] = self.synset_dict[model["synset_id"]]
 
         model["images"] = None
@@ -240,7 +261,7 @@ class R2N2(ShapeNetBase):
         if self.return_images:
             rendering_path = path.join(
                 self.r2n2_dir,
-                "ShapeNetRendering",
+                self.views_rel_path,
                 model["synset_id"],
                 model["model_id"],
                 "rendering",
@@ -284,10 +305,11 @@ class R2N2(ShapeNetBase):
             model["K"] = K.expand(len(model_views), 4, 4)
 
         voxels_list = []
+
         # Read voxels if required.
         voxel_path = path.join(
             self.r2n2_dir,
-            "ShapeNetVox32",
+            self.voxels_rel_path,
             model["synset_id"],
             model["model_id"],
             "model.binvox",
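
A sketch of how the new R2N2 arguments compose; all paths are placeholders, not values from the commit.

    from pytorch3d.datasets import R2N2

    r2n2_dataset = R2N2(
        "train",
        "/path/to/ShapeNetCore.v1",   # hypothetical ShapeNet root
        "/path/to/R2N2",              # hypothetical R2N2 root
        "/path/to/splits.json",       # hypothetical splits file
        return_voxels=True,
        views_rel_path="ShapeNetRendering",  # the default, shown explicitly
        voxels_rel_path="ShapeNetVoxels",    # the default, shown explicitly
        load_textures=True,
        texture_resolution=4,
    )
    model = r2n2_dataset[0]
    # model["textures"] holds the per-face atlas produced by _load_mesh.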
pytorch3d/datasets/shapenet/shapenet_core.py

@@ -8,7 +8,6 @@ from pathlib import Path
 from typing import Dict
 
 from pytorch3d.datasets.shapenet_base import ShapeNetBase
-from pytorch3d.io import load_obj
 
 
 SYNSET_DICT_DIR = Path(__file__).resolve().parent
@@ -21,7 +20,14 @@ class ShapeNetCore(ShapeNetBase):
     https://www.shapenet.org/.
     """
 
-    def __init__(self, data_dir, synsets=None, version: int = 1):
+    def __init__(
+        self,
+        data_dir,
+        synsets=None,
+        version: int = 1,
+        load_textures: bool = True,
+        texture_resolution: int = 4,
+    ):
         """
         Store each object's synset id and model id from data_dir.
 
@@ -38,10 +44,17 @@ class ShapeNetCore(ShapeNetBase):
                 respectively. You can combine the categories manually if needed.
                 Version 2 doesn't have 02858304(boat) or 02834778(bicycle) compared to
                 version 1.
+            load_textures: Boolean indicating whether textures should be loaded for the model.
+                Textures will be of type TexturesAtlas i.e. a texture map per face.
+            texture_resolution: Int specifying the resolution of the texture map per face
+                created using the textures in the obj file. A
+                (texture_resolution, texture_resolution, 3) map is created per face.
         """
         super().__init__()
         self.shapenet_dir = data_dir
+        self.load_textures = load_textures
+        self.texture_resolution = texture_resolution
 
         if version not in [1, 2]:
             raise ValueError("Version number must be either 1 or 2.")
         self.model_dir = "model.obj" if version == 1 else "models/model_normalized.obj"
@@ -133,7 +146,9 @@ class ShapeNetCore(ShapeNetBase):
         model_path = path.join(
             self.shapenet_dir, model["synset_id"], model["model_id"], self.model_dir
         )
-        model["verts"], faces, _ = load_obj(model_path)
-        model["faces"] = faces.verts_idx
+        verts, faces, textures = self._load_mesh(model_path)
+        model["verts"] = verts
+        model["faces"] = faces
+        model["textures"] = textures
         model["label"] = self.synset_dict[model["synset_id"]]
         return model
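
The constructor keeps the existing version handling; a short sketch of how it interacts with the new arguments (path hypothetical).

    from pytorch3d.datasets import ShapeNetCore

    # version selects the expected obj layout on disk:
    #   v1: <synset_id>/<model_id>/model.obj
    #   v2: <synset_id>/<model_id>/models/model_normalized.obj
    dataset_v2 = ShapeNetCore(
        "/path/to/ShapeNetCore.v2",  # hypothetical local path
        version=2,
        load_textures=True,
        texture_resolution=4,
    )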
pytorch3d/datasets/shapenet_base.py

@@ -1,11 +1,10 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 import warnings
-from os import path
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
-from pytorch3d.io import load_objs_as_meshes
+from pytorch3d.io import load_obj
 from pytorch3d.renderer import (
     FoVPerspectiveCameras,
     HardPhongShader,
@@ -16,6 +15,8 @@ from pytorch3d.renderer import (
     TexturesVertex,
 )
 
+from .utils import collate_batched_meshes
+
 
 class ShapeNetBase(torch.utils.data.Dataset):
     """
@@ -35,6 +36,8 @@ class ShapeNetBase(torch.utils.data.Dataset):
         self.synset_num_models = {}
         self.shapenet_dir = ""
         self.model_dir = "model.obj"
+        self.load_textures = True
+        self.texture_resolution = 4
 
     def __len__(self):
         """
@@ -74,6 +77,27 @@ class ShapeNetBase(torch.utils.data.Dataset):
         model["model_id"] = self.model_ids[idx]
         return model
 
+    def _load_mesh(self, model_path) -> Tuple:
+        verts, faces, aux = load_obj(
+            model_path,
+            create_texture_atlas=self.load_textures,
+            load_textures=self.load_textures,
+            texture_atlas_size=self.texture_resolution,
+        )
+        if self.load_textures:
+            textures = aux.texture_atlas
+            # Some meshes don't have textures. In this case
+            # create a white texture map
+            if textures is None:
+                textures = verts.new_ones(
+                    faces.verts_idx.shape[0],
+                    self.texture_resolution,
+                    self.texture_resolution,
+                    3,
+                )
+
+        return verts, faces.verts_idx, textures
+
     def render(
         self,
         model_ids: Optional[List[str]] = None,
@@ -112,19 +136,15 @@ class ShapeNetBase(torch.utils.data.Dataset):
             Batch of rendered images of shape (N, H, W, 3).
         """
         idxs = self._handle_render_inputs(model_ids, categories, sample_nums, idxs)
-        paths = [
-            path.join(
-                self.shapenet_dir,
-                self.synset_ids[idx],
-                self.model_ids[idx],
-                self.model_dir,
-            )
-            for idx in idxs
-        ]
-        meshes = load_objs_as_meshes(paths, device=device, load_textures=False)
-        meshes.textures = TexturesVertex(
-            verts_features=torch.ones_like(meshes.verts_padded(), device=device)
-        )
+        # Use the getitem method which loads mesh + texture
+        models = [self[idx] for idx in idxs]
+        meshes = collate_batched_meshes(models)["mesh"]
+        if meshes.textures is None:
+            meshes.textures = TexturesVertex(
+                verts_features=torch.ones_like(meshes.verts_padded(), device=device)
+            )
+
+        meshes = meshes.to(device)
         cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device)
         if len(cameras) != 1 and len(cameras) % len(meshes) != 0:
             raise ValueError("Mismatch between batch dims of cameras and meshes.")
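
The fallback branch in _load_mesh synthesizes an all-white per-face atlas when an obj carries no usable texture. A standalone illustration of that tensor construction; the sizes here are arbitrary.

    import torch

    num_faces, resolution = 10, 4
    verts = torch.rand(8, 3)
    # verts.new_ones inherits verts' dtype and device, so the synthetic
    # atlas lives wherever the mesh does; all-ones RGB shades as white.
    atlas = verts.new_ones(num_faces, resolution, resolution, 3)
    assert atlas.shape == (10, 4, 4, 3)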
pytorch3d/datasets/utils.py

@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 from typing import Dict, List
 
+from pytorch3d.renderer.mesh import TexturesAtlas
 from pytorch3d.structures import Meshes
 
 
@@ -28,7 +29,15 @@ def collate_batched_meshes(batch: List[Dict]):
 
     collated_dict["mesh"] = None
     if {"verts", "faces"}.issubset(collated_dict.keys()):
+
+        textures = None
+        if "textures" in collated_dict:
+            textures = TexturesAtlas(atlas=collated_dict["textures"])
+
         collated_dict["mesh"] = Meshes(
-            verts=collated_dict["verts"], faces=collated_dict["faces"]
+            verts=collated_dict["verts"],
+            faces=collated_dict["faces"],
+            textures=textures,
         )
 
     return collated_dict
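
A minimal sketch of batched loading with the updated collate function, assuming a local dataset at a hypothetical path.

    from torch.utils.data import DataLoader
    from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes

    dataset = ShapeNetCore("/path/to/ShapeNetCore.v1")  # hypothetical path
    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_batched_meshes)
    batch = next(iter(loader))
    mesh = batch["mesh"]  # Meshes carrying TexturesAtlas when textures exist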
pytorch3d/io/obj_io.py

@@ -253,7 +253,7 @@ def load_objs_as_meshes(
         tex = None
         if create_texture_atlas:
             # TexturesAtlas type
-            tex = TexturesAtlas(atlas=[aux.texture_atlas])
+            tex = TexturesAtlas(atlas=[aux.texture_atlas.to(device)])
         else:
             # TexturesUV type
             tex_maps = aux.texture_images
@@ -477,18 +477,20 @@ def _load_obj(
             face_material_names = np.array(material_names)[idx]  # (F,)
             face_material_names[idx == -1] = ""
 
-            # Get the uv coords for each vert in each face
-            faces_verts_uvs = verts_uvs[faces_textures_idx]  # (F, 3, 2)
+            texture_atlas = None
+            if len(verts_uvs) > 0:
+                # Get the uv coords for each vert in each face
+                faces_verts_uvs = verts_uvs[faces_textures_idx]  # (F, 3, 2)
 
-            # Construct the atlas.
-            texture_atlas = make_mesh_texture_atlas(
-                material_colors,
-                texture_images,
-                face_material_names,
-                faces_verts_uvs,
-                texture_atlas_size,
-                texture_wrap,
-            )
+                # Construct the atlas.
+                texture_atlas = make_mesh_texture_atlas(
+                    material_colors,
+                    texture_images,
+                    face_material_names,
+                    faces_verts_uvs,
+                    texture_atlas_size,
+                    texture_wrap,
+                )
         else:
             warnings.warn(f"Mtl file does not exist: {f_mtl}")
     elif len(material_names) > 0:
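
This is the lower-level entry point the dataset classes call into. A hedged single-file sketch; the obj path is hypothetical.

    from pytorch3d.io import load_obj

    verts, faces, aux = load_obj(
        "model.obj",              # hypothetical obj with an accompanying .mtl
        load_textures=True,
        create_texture_atlas=True,
        texture_atlas_size=4,
    )
    # With this change, aux.texture_atlas may be None when the obj defines
    # no UV coordinates; callers should handle that case (as _load_mesh does).
    atlas = aux.texture_atlas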
tests/test_r2n2.py

@@ -33,6 +33,8 @@ from torch.utils.data import DataLoader
 R2N2_PATH = None
 SHAPENET_PATH = None
 SPLITS_PATH = None
+VOXELS_REL_PATH = "ShapeNetVox"
+
 
 DEBUG = False
 DATA_DIR = Path(__file__).resolve().parent / "data"
@@ -69,7 +71,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         """
         # Load dataset in the train split.
         r2n2_dataset = R2N2(
-            "test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+            "test",
+            SHAPENET_PATH,
+            R2N2_PATH,
+            SPLITS_PATH,
+            return_voxels=True,
+            voxels_rel_path=VOXELS_REL_PATH,
         )
 
         # Check total number of objects in the dataset is correct.
@@ -133,7 +140,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         """
         # Load dataset in the train split.
         r2n2_dataset = R2N2(
-            "val", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+            "val",
+            SHAPENET_PATH,
+            R2N2_PATH,
+            SPLITS_PATH,
+            return_voxels=True,
+            voxels_rel_path=VOXELS_REL_PATH,
         )
 
         # Randomly retrieve several objects from the dataset and collate them.
@@ -362,7 +374,12 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
 
         # Load dataset in the train split with only a single view returned for each model.
         r2n2_dataset = R2N2(
-            "train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_voxels=True
+            "train",
+            SHAPENET_PATH,
+            R2N2_PATH,
+            SPLITS_PATH,
+            return_voxels=True,
+            voxels_rel_path=VOXELS_REL_PATH,
        )
         r2n2_model = r2n2_dataset[6, [5]]
         vox_render = render_cubified_voxels(r2n2_model["voxels"], device=device)
tests/test_shapenet_core.py

@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 
 # Set the SHAPENET_PATH to the local path to the dataset
 SHAPENET_PATH = None
+VERSION = 1
 # If DEBUG=True, save out images generated in the tests for debugging.
 # All saved images have prefix DEBUG_
 DEBUG = False
@@ -55,7 +56,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
         self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
 
         # Load ShapeNetCore without specifying any particular categories.
-        shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
+        shapenet_dataset = ShapeNetCore(SHAPENET_PATH, version=VERSION)
 
         # Count the number of grandchildren directories (which should be equal to
         # the total number of objects in the dataset) by walking through the given