mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Adding renderer for ShapeNetBase
Summary: Adding a renderer to ShapeNetCore (Note that the lights are currently turned off for the test; will investigate why lighting causes instability in rendering) Reviewed By: nikhilaravi Differential Revision: D22102673 fbshipit-source-id: a704756a1e93b61d5a879f0e5ee14ebcb0df49d7
This commit is contained in:
parent
09c1762939
commit
358e211cde
@ -1,4 +1,5 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .shapenet_core import ShapeNetCore
|
||||
|
||||
|
||||
|
@ -5,15 +5,16 @@ import os
|
||||
import warnings
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
||||
from pytorch3d.io import load_obj
|
||||
|
||||
|
||||
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class ShapeNetCore(torch.utils.data.Dataset):
|
||||
class ShapeNetCore(ShapeNetBase):
|
||||
"""
|
||||
This class loads ShapeNetCore from a given directory into a Dataset object.
|
||||
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
|
||||
@ -23,6 +24,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
||||
def __init__(self, data_dir, synsets=None, version: int = 1):
|
||||
"""
|
||||
Store each object's synset id and models id from data_dir.
|
||||
|
||||
Args:
|
||||
data_dir: Path to ShapeNetCore data.
|
||||
synsets: List of synset categories to load from ShapeNetCore in the form of
|
||||
@ -38,6 +40,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
||||
version 1.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.data_dir = data_dir
|
||||
if version not in [1, 2]:
|
||||
raise ValueError("Version number must be either 1 or 2.")
|
||||
@ -48,7 +51,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
||||
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
|
||||
self.synset_dict = json.load(read_dict)
|
||||
# Inverse dicitonary mapping synset labels to corresponding offsets.
|
||||
synset_inv = {label: offset for offset, label in self.synset_dict.items()}
|
||||
self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}
|
||||
|
||||
# If categories are specified, check if each category is in the form of either
|
||||
# synset offset or synset label, and if the category exists in the given directory.
|
||||
@ -60,62 +63,61 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
||||
path.isdir(path.join(data_dir, synset))
|
||||
):
|
||||
synset_set.add(synset)
|
||||
elif (synset in synset_inv.keys()) and (
|
||||
(path.isdir(path.join(data_dir, synset_inv[synset])))
|
||||
elif (synset in self.synset_inv.keys()) and (
|
||||
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
|
||||
):
|
||||
synset_set.add(synset_inv[synset])
|
||||
synset_set.add(self.synset_inv[synset])
|
||||
else:
|
||||
msg = """Synset category %s either not part of ShapeNetCore dataset
|
||||
or cannot be found in %s.""" % (
|
||||
synset,
|
||||
data_dir,
|
||||
)
|
||||
msg = (
|
||||
"Synset category %s either not part of ShapeNetCore dataset "
|
||||
"or cannot be found in %s."
|
||||
) % (synset, data_dir)
|
||||
warnings.warn(msg)
|
||||
# If no category is given, load every category in the given directory.
|
||||
# Ignore synset folders not included in the official mapping.
|
||||
else:
|
||||
synset_set = {
|
||||
synset
|
||||
for synset in os.listdir(data_dir)
|
||||
if path.isdir(path.join(data_dir, synset))
|
||||
and synset in self.synset_dict
|
||||
}
|
||||
for synset in synset_set:
|
||||
if synset not in self.synset_dict.keys():
|
||||
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
|
||||
but not found in %s.""" % (
|
||||
synset,
|
||||
self.synset_dict[synset],
|
||||
version,
|
||||
data_dir,
|
||||
)
|
||||
warnings.warn(msg)
|
||||
|
||||
# Check if there are any categories in the official mapping that are not loaded.
|
||||
# Update self.synset_inv so that it only includes the loaded categories.
|
||||
synset_not_present = set(self.synset_dict.keys()).difference(synset_set)
|
||||
[self.synset_inv.pop(self.synset_dict[synset]) for synset in synset_not_present]
|
||||
|
||||
if len(synset_not_present) > 0:
|
||||
msg = (
|
||||
"The following categories are included in ShapeNetCore ver.%d's "
|
||||
"official mapping but not found in the dataset location %s: %s"
|
||||
""
|
||||
) % (version, data_dir, ", ".join(synset_not_present))
|
||||
warnings.warn(msg)
|
||||
|
||||
# Extract model_id of each object from directory names.
|
||||
# Each grandchildren directory of data_dir contains an object, and the name
|
||||
# of the directory is the object's model_id.
|
||||
self.synset_ids = []
|
||||
self.model_ids = []
|
||||
for synset in synset_set:
|
||||
for model in os.listdir(path.join(data_dir, synset)):
|
||||
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
|
||||
msg = """ Object file not found in the model directory %s
|
||||
under synset directory %s.""" % (
|
||||
model,
|
||||
synset,
|
||||
)
|
||||
msg = (
|
||||
"Object file not found in the model directory %s "
|
||||
"under synset directory %s."
|
||||
) % (model, synset)
|
||||
warnings.warn(msg)
|
||||
else:
|
||||
self.synset_ids.append(synset)
|
||||
self.model_ids.append(model)
|
||||
continue
|
||||
self.synset_ids.append(synset)
|
||||
self.model_ids.append(model)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return number of total models in shapenet core.
|
||||
"""
|
||||
return len(self.model_ids)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
"""
|
||||
Read a model by the given index.
|
||||
|
||||
Args:
|
||||
idx: The idx of the model to be retrieved in the dataset.
|
||||
|
||||
Returns:
|
||||
dictionary with following keys:
|
||||
- verts: FloatTensor of shape (V, 3).
|
||||
@ -124,9 +126,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
|
||||
- model_id (str): model id
|
||||
- label (str): synset label.
|
||||
"""
|
||||
model = {}
|
||||
model["synset_id"] = self.synset_ids[idx]
|
||||
model["model_id"] = self.model_ids[idx]
|
||||
model = self._get_item_ids(idx)
|
||||
model_path = path.join(
|
||||
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
|
||||
)
|
||||
|
107
pytorch3d/datasets/shapenet_base.py
Normal file
107
pytorch3d/datasets/shapenet_base.py
Normal file
@ -0,0 +1,107 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch3d.renderer import (
|
||||
HardPhongShader,
|
||||
MeshRasterizer,
|
||||
MeshRenderer,
|
||||
OpenGLPerspectiveCameras,
|
||||
PointLights,
|
||||
RasterizationSettings,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, Textures
|
||||
|
||||
|
||||
class ShapeNetBase(torch.utils.data.Dataset):
|
||||
"""
|
||||
'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods.
|
||||
It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__
|
||||
and __getitem__ need to be implemented.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Set up lists of synset_ids and model_ids.
|
||||
"""
|
||||
self.synset_ids = []
|
||||
self.model_ids = []
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Return number of total models in the loaded dataset.
|
||||
"""
|
||||
return len(self.model_ids)
|
||||
|
||||
def __getitem__(self, idx) -> Dict:
|
||||
"""
|
||||
Read a model by the given index. Need to be implemented for every child class
|
||||
of ShapeNetBase.
|
||||
|
||||
Args:
|
||||
idx: The idx of the model to be retrieved in the dataset.
|
||||
|
||||
Returns:
|
||||
dictionary containing information about the model.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
"__getitem__ should be implemented in the child class of ShapeNetBase"
|
||||
)
|
||||
|
||||
def _get_item_ids(self, idx) -> Dict:
|
||||
"""
|
||||
Read a model by the given index.
|
||||
|
||||
Args:
|
||||
idx: The idx of the model to be retrieved in the dataset.
|
||||
|
||||
Returns:
|
||||
dictionary with following keys:
|
||||
- synset_id (str): synset id
|
||||
- model_id (str): model id
|
||||
"""
|
||||
model = {}
|
||||
model["synset_id"] = self.synset_ids[idx]
|
||||
model["model_id"] = self.model_ids[idx]
|
||||
return model
|
||||
|
||||
def render(
|
||||
self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Renders a model by the given index.
|
||||
|
||||
Args:
|
||||
idx: The index of model to be rendered in the dataset.
|
||||
shader_type: select shading. Valid options include HardPhongShader (default),
|
||||
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
|
||||
SoftSilhouetteShader.
|
||||
device: torch.device on which the tensors should be located.
|
||||
**kwargs: Accepts any of the kwargs that the renderer supports.
|
||||
|
||||
Returns:
|
||||
Rendered image of shape (1, H, W, 3).
|
||||
"""
|
||||
|
||||
model = self.__getitem__(idx)
|
||||
verts, faces = model["verts"], model["faces"]
|
||||
verts_rgb = torch.ones_like(verts, device=device)[None]
|
||||
mesh = Meshes(
|
||||
verts=[verts.to(device)],
|
||||
faces=[faces.to(device)],
|
||||
textures=Textures(verts_rgb=verts_rgb.to(device)),
|
||||
)
|
||||
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(
|
||||
cameras=cameras,
|
||||
raster_settings=kwargs.get("raster_settings", RasterizationSettings()),
|
||||
),
|
||||
shader=shader_type(
|
||||
device=device,
|
||||
cameras=cameras,
|
||||
lights=kwargs.get("lights", PointLights()).to(device),
|
||||
),
|
||||
)
|
||||
return renderer(mesh)
|
BIN
tests/data/test_shapenet_core_render_piano.png
Normal file
BIN
tests/data/test_shapenet_core_render_piano.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 3.2 KiB |
@ -6,17 +6,32 @@ import os
|
||||
import random
|
||||
import unittest
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from common_testing import TestCaseMixin, load_rgb_image
|
||||
from PIL import Image
|
||||
from pytorch3d.datasets import ShapeNetCore
|
||||
from pytorch3d.renderer import (
|
||||
OpenGLPerspectiveCameras,
|
||||
PointLights,
|
||||
RasterizationSettings,
|
||||
look_at_view_transform,
|
||||
)
|
||||
|
||||
|
||||
SHAPENET_PATH = None
|
||||
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||
# All saved images have prefix DEBUG_
|
||||
DEBUG = False
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
def test_load_shapenet_core(self):
|
||||
# Setup
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# The ShapeNet dataset is not provided in the repo.
|
||||
# Download this separately and update the `shapenet_path`
|
||||
@ -31,7 +46,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
warnings.warn(msg)
|
||||
return True
|
||||
|
||||
# Try load ShapeNetCore with an invalid version number and catch error.
|
||||
# Try loading ShapeNetCore with an invalid version number and catch error.
|
||||
with self.assertRaises(ValueError) as err:
|
||||
ShapeNetCore(SHAPENET_PATH, version=3)
|
||||
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
|
||||
@ -93,3 +108,31 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
for offset in subset_offsets
|
||||
]
|
||||
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
|
||||
|
||||
# Render the first image in the piano category.
|
||||
R, T = look_at_view_transform(1.0, 1.0, 90)
|
||||
piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])
|
||||
|
||||
cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
|
||||
raster_settings = RasterizationSettings(image_size=512)
|
||||
lights = PointLights(
|
||||
location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
|
||||
# TODO: debug the source of the discrepancy in two images when rendering on GPU.
|
||||
diffuse_color=((0, 0, 0),),
|
||||
specular_color=((0, 0, 0),),
|
||||
device=device,
|
||||
)
|
||||
images = piano_dataset.render(
|
||||
0,
|
||||
device=device,
|
||||
cameras=cameras,
|
||||
raster_settings=raster_settings,
|
||||
lights=lights,
|
||||
)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_shapenet_core_render_piano.png"
|
||||
)
|
||||
image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
Loading…
x
Reference in New Issue
Block a user