mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Adding renderer for ShapeNetBase
Summary: Adding a renderer to ShapeNetCore (Note that the lights are currently turned off for the test; will investigate why lighting causes instability in rendering) Reviewed By: nikhilaravi Differential Revision: D22102673 fbshipit-source-id: a704756a1e93b61d5a879f0e5ee14ebcb0df49d7
This commit is contained in:
parent
09c1762939
commit
358e211cde
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from .shapenet_core import ShapeNetCore
|
from .shapenet_core import ShapeNetCore
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,15 +5,16 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
from os import path
|
from os import path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
import torch
|
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
||||||
from pytorch3d.io import load_obj
|
from pytorch3d.io import load_obj
|
||||||
|
|
||||||
|
|
||||||
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
class ShapeNetCore(torch.utils.data.Dataset):
|
class ShapeNetCore(ShapeNetBase):
|
||||||
"""
|
"""
|
||||||
This class loads ShapeNetCore from a given directory into a Dataset object.
|
This class loads ShapeNetCore from a given directory into a Dataset object.
|
||||||
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
|
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
|
||||||
@ -23,6 +24,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
|||||||
def __init__(self, data_dir, synsets=None, version: int = 1):
|
def __init__(self, data_dir, synsets=None, version: int = 1):
|
||||||
"""
|
"""
|
||||||
Store each object's synset id and models id from data_dir.
|
Store each object's synset id and models id from data_dir.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_dir: Path to ShapeNetCore data.
|
data_dir: Path to ShapeNetCore data.
|
||||||
synsets: List of synset categories to load from ShapeNetCore in the form of
|
synsets: List of synset categories to load from ShapeNetCore in the form of
|
||||||
@ -38,6 +40,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
|||||||
version 1.
|
version 1.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
super().__init__()
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
if version not in [1, 2]:
|
if version not in [1, 2]:
|
||||||
raise ValueError("Version number must be either 1 or 2.")
|
raise ValueError("Version number must be either 1 or 2.")
|
||||||
@ -48,7 +51,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
|||||||
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
|
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
|
||||||
self.synset_dict = json.load(read_dict)
|
self.synset_dict = json.load(read_dict)
|
||||||
# Inverse dicitonary mapping synset labels to corresponding offsets.
|
# Inverse dicitonary mapping synset labels to corresponding offsets.
|
||||||
synset_inv = {label: offset for offset, label in self.synset_dict.items()}
|
self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}
|
||||||
|
|
||||||
# If categories are specified, check if each category is in the form of either
|
# If categories are specified, check if each category is in the form of either
|
||||||
# synset offset or synset label, and if the category exists in the given directory.
|
# synset offset or synset label, and if the category exists in the given directory.
|
||||||
@ -60,62 +63,61 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
|||||||
path.isdir(path.join(data_dir, synset))
|
path.isdir(path.join(data_dir, synset))
|
||||||
):
|
):
|
||||||
synset_set.add(synset)
|
synset_set.add(synset)
|
||||||
elif (synset in synset_inv.keys()) and (
|
elif (synset in self.synset_inv.keys()) and (
|
||||||
(path.isdir(path.join(data_dir, synset_inv[synset])))
|
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
|
||||||
):
|
):
|
||||||
synset_set.add(synset_inv[synset])
|
synset_set.add(self.synset_inv[synset])
|
||||||
else:
|
else:
|
||||||
msg = """Synset category %s either not part of ShapeNetCore dataset
|
msg = (
|
||||||
or cannot be found in %s.""" % (
|
"Synset category %s either not part of ShapeNetCore dataset "
|
||||||
synset,
|
"or cannot be found in %s."
|
||||||
data_dir,
|
) % (synset, data_dir)
|
||||||
)
|
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
# If no category is given, load every category in the given directory.
|
# If no category is given, load every category in the given directory.
|
||||||
|
# Ignore synset folders not included in the official mapping.
|
||||||
else:
|
else:
|
||||||
synset_set = {
|
synset_set = {
|
||||||
synset
|
synset
|
||||||
for synset in os.listdir(data_dir)
|
for synset in os.listdir(data_dir)
|
||||||
if path.isdir(path.join(data_dir, synset))
|
if path.isdir(path.join(data_dir, synset))
|
||||||
|
and synset in self.synset_dict
|
||||||
}
|
}
|
||||||
for synset in synset_set:
|
|
||||||
if synset not in self.synset_dict.keys():
|
# Check if there are any categories in the official mapping that are not loaded.
|
||||||
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
|
# Update self.synset_inv so that it only includes the loaded categories.
|
||||||
but not found in %s.""" % (
|
synset_not_present = set(self.synset_dict.keys()).difference(synset_set)
|
||||||
synset,
|
[self.synset_inv.pop(self.synset_dict[synset]) for synset in synset_not_present]
|
||||||
self.synset_dict[synset],
|
|
||||||
version,
|
if len(synset_not_present) > 0:
|
||||||
data_dir,
|
msg = (
|
||||||
)
|
"The following categories are included in ShapeNetCore ver.%d's "
|
||||||
warnings.warn(msg)
|
"official mapping but not found in the dataset location %s: %s"
|
||||||
|
""
|
||||||
|
) % (version, data_dir, ", ".join(synset_not_present))
|
||||||
|
warnings.warn(msg)
|
||||||
|
|
||||||
# Extract model_id of each object from directory names.
|
# Extract model_id of each object from directory names.
|
||||||
# Each grandchildren directory of data_dir contains an object, and the name
|
# Each grandchildren directory of data_dir contains an object, and the name
|
||||||
# of the directory is the object's model_id.
|
# of the directory is the object's model_id.
|
||||||
self.synset_ids = []
|
|
||||||
self.model_ids = []
|
|
||||||
for synset in synset_set:
|
for synset in synset_set:
|
||||||
for model in os.listdir(path.join(data_dir, synset)):
|
for model in os.listdir(path.join(data_dir, synset)):
|
||||||
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
|
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
|
||||||
msg = """ Object file not found in the model directory %s
|
msg = (
|
||||||
under synset directory %s.""" % (
|
"Object file not found in the model directory %s "
|
||||||
model,
|
"under synset directory %s."
|
||||||
synset,
|
) % (model, synset)
|
||||||
)
|
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
else:
|
continue
|
||||||
self.synset_ids.append(synset)
|
self.synset_ids.append(synset)
|
||||||
self.model_ids.append(model)
|
self.model_ids.append(model)
|
||||||
|
|
||||||
def __len__(self):
|
def __getitem__(self, idx: int) -> Dict:
|
||||||
"""
|
|
||||||
Return number of total models in shapenet core.
|
|
||||||
"""
|
|
||||||
return len(self.model_ids)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
"""
|
"""
|
||||||
Read a model by the given index.
|
Read a model by the given index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: The idx of the model to be retrieved in the dataset.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dictionary with following keys:
|
dictionary with following keys:
|
||||||
- verts: FloatTensor of shape (V, 3).
|
- verts: FloatTensor of shape (V, 3).
|
||||||
@ -124,9 +126,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
|||||||
- model_id (str): model id
|
- model_id (str): model id
|
||||||
- label (str): synset label.
|
- label (str): synset label.
|
||||||
"""
|
"""
|
||||||
model = {}
|
model = self._get_item_ids(idx)
|
||||||
model["synset_id"] = self.synset_ids[idx]
|
|
||||||
model["model_id"] = self.model_ids[idx]
|
|
||||||
model_path = path.join(
|
model_path = path.join(
|
||||||
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
|
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
|
||||||
)
|
)
|
||||||
|
107
pytorch3d/datasets/shapenet_base.py
Normal file
107
pytorch3d/datasets/shapenet_base.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
HardPhongShader,
|
||||||
|
MeshRasterizer,
|
||||||
|
MeshRenderer,
|
||||||
|
OpenGLPerspectiveCameras,
|
||||||
|
PointLights,
|
||||||
|
RasterizationSettings,
|
||||||
|
)
|
||||||
|
from pytorch3d.structures import Meshes, Textures
|
||||||
|
|
||||||
|
|
||||||
|
class ShapeNetBase(torch.utils.data.Dataset):
|
||||||
|
"""
|
||||||
|
'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods.
|
||||||
|
It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__
|
||||||
|
and __getitem__ need to be implemented.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""
|
||||||
|
Set up lists of synset_ids and model_ids.
|
||||||
|
"""
|
||||||
|
self.synset_ids = []
|
||||||
|
self.model_ids = []
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Return number of total models in the loaded dataset.
|
||||||
|
"""
|
||||||
|
return len(self.model_ids)
|
||||||
|
|
||||||
|
def __getitem__(self, idx) -> Dict:
|
||||||
|
"""
|
||||||
|
Read a model by the given index. Need to be implemented for every child class
|
||||||
|
of ShapeNetBase.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: The idx of the model to be retrieved in the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dictionary containing information about the model.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"__getitem__ should be implemented in the child class of ShapeNetBase"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_item_ids(self, idx) -> Dict:
|
||||||
|
"""
|
||||||
|
Read a model by the given index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: The idx of the model to be retrieved in the dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dictionary with following keys:
|
||||||
|
- synset_id (str): synset id
|
||||||
|
- model_id (str): model id
|
||||||
|
"""
|
||||||
|
model = {}
|
||||||
|
model["synset_id"] = self.synset_ids[idx]
|
||||||
|
model["model_id"] = self.model_ids[idx]
|
||||||
|
return model
|
||||||
|
|
||||||
|
def render(
|
||||||
|
self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Renders a model by the given index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: The index of model to be rendered in the dataset.
|
||||||
|
shader_type: select shading. Valid options include HardPhongShader (default),
|
||||||
|
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
|
||||||
|
SoftSilhouetteShader.
|
||||||
|
device: torch.device on which the tensors should be located.
|
||||||
|
**kwargs: Accepts any of the kwargs that the renderer supports.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered image of shape (1, H, W, 3).
|
||||||
|
"""
|
||||||
|
|
||||||
|
model = self.__getitem__(idx)
|
||||||
|
verts, faces = model["verts"], model["faces"]
|
||||||
|
verts_rgb = torch.ones_like(verts, device=device)[None]
|
||||||
|
mesh = Meshes(
|
||||||
|
verts=[verts.to(device)],
|
||||||
|
faces=[faces.to(device)],
|
||||||
|
textures=Textures(verts_rgb=verts_rgb.to(device)),
|
||||||
|
)
|
||||||
|
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
|
||||||
|
renderer = MeshRenderer(
|
||||||
|
rasterizer=MeshRasterizer(
|
||||||
|
cameras=cameras,
|
||||||
|
raster_settings=kwargs.get("raster_settings", RasterizationSettings()),
|
||||||
|
),
|
||||||
|
shader=shader_type(
|
||||||
|
device=device,
|
||||||
|
cameras=cameras,
|
||||||
|
lights=kwargs.get("lights", PointLights()).to(device),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return renderer(mesh)
|
BIN
tests/data/test_shapenet_core_render_piano.png
Normal file
BIN
tests/data/test_shapenet_core_render_piano.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.2 KiB |
@ -6,17 +6,32 @@ import os
|
|||||||
import random
|
import random
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin, load_rgb_image
|
||||||
|
from PIL import Image
|
||||||
from pytorch3d.datasets import ShapeNetCore
|
from pytorch3d.datasets import ShapeNetCore
|
||||||
|
from pytorch3d.renderer import (
|
||||||
|
OpenGLPerspectiveCameras,
|
||||||
|
PointLights,
|
||||||
|
RasterizationSettings,
|
||||||
|
look_at_view_transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
SHAPENET_PATH = None
|
SHAPENET_PATH = None
|
||||||
|
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||||
|
# All saved images have prefix DEBUG_
|
||||||
|
DEBUG = False
|
||||||
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||||
|
|
||||||
|
|
||||||
class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||||
def test_load_shapenet_core(self):
|
def test_load_shapenet_core(self):
|
||||||
|
# Setup
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
# The ShapeNet dataset is not provided in the repo.
|
# The ShapeNet dataset is not provided in the repo.
|
||||||
# Download this separately and update the `shapenet_path`
|
# Download this separately and update the `shapenet_path`
|
||||||
@ -31,7 +46,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
|||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Try load ShapeNetCore with an invalid version number and catch error.
|
# Try loading ShapeNetCore with an invalid version number and catch error.
|
||||||
with self.assertRaises(ValueError) as err:
|
with self.assertRaises(ValueError) as err:
|
||||||
ShapeNetCore(SHAPENET_PATH, version=3)
|
ShapeNetCore(SHAPENET_PATH, version=3)
|
||||||
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
|
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
|
||||||
@ -93,3 +108,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
|||||||
for offset in subset_offsets
|
for offset in subset_offsets
|
||||||
]
|
]
|
||||||
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
|
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
|
||||||
|
|
||||||
|
# Render the first image in the piano category.
|
||||||
|
R, T = look_at_view_transform(1.0, 1.0, 90)
|
||||||
|
piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])
|
||||||
|
|
||||||
|
cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
|
||||||
|
raster_settings = RasterizationSettings(image_size=512)
|
||||||
|
lights = PointLights(
|
||||||
|
location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
|
||||||
|
# TODO: debug the source of the discrepancy in two images when rendering on GPU.
|
||||||
|
diffuse_color=((0, 0, 0),),
|
||||||
|
specular_color=((0, 0, 0),),
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
images = piano_dataset.render(
|
||||||
|
0,
|
||||||
|
device=device,
|
||||||
|
cameras=cameras,
|
||||||
|
raster_settings=raster_settings,
|
||||||
|
lights=lights,
|
||||||
|
)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
if DEBUG:
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / "DEBUG_shapenet_core_render_piano.png"
|
||||||
|
)
|
||||||
|
image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR)
|
||||||
|
self.assertClose(rgb, image_ref, atol=0.05)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user