mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Return R2N2 renderings
Summary: R2N2 returns R2N2's own renderings of ShapeNetCore models. Reviewed By: nikhilaravi Differential Revision: D22266988 fbshipit-source-id: 36e67bd06c6459773e6e5f654259166b579be36a
This commit is contained in:
parent
5636eb6152
commit
dc08c30583
@ -4,8 +4,11 @@ import json
|
|||||||
import warnings
|
import warnings
|
||||||
from os import path
|
from os import path
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
||||||
from pytorch3d.io import load_obj
|
from pytorch3d.io import load_obj
|
||||||
|
|
||||||
@ -21,18 +24,29 @@ class R2N2(ShapeNetBase):
|
|||||||
voxelized models.
|
voxelized models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
|
def __init__(
|
||||||
|
self,
|
||||||
|
split: str,
|
||||||
|
shapenet_dir,
|
||||||
|
r2n2_dir,
|
||||||
|
splits_file,
|
||||||
|
return_all_views: bool = True,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Store each object's synset id and models id the given directories.
|
Store each object's synset id and models id the given directories.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
split (str): One of (train, val, test).
|
split (str): One of (train, val, test).
|
||||||
shapenet_dir (path): Path to ShapeNet core v1.
|
shapenet_dir (path): Path to ShapeNet core v1.
|
||||||
r2n2_dir (path): Path to the R2N2 dataset.
|
r2n2_dir (path): Path to the R2N2 dataset.
|
||||||
splits_file (path): File containing the train/val/test splits.
|
splits_file (path): File containing the train/val/test splits.
|
||||||
|
return_all_views (bool): Indicator of whether or not to return all 24 views. If set
|
||||||
|
to False, one of the 24 views would be randomly selected and returned.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.shapenet_dir = shapenet_dir
|
self.shapenet_dir = shapenet_dir
|
||||||
self.r2n2_dir = r2n2_dir
|
self.r2n2_dir = r2n2_dir
|
||||||
|
self.return_all_views = return_all_views
|
||||||
# Examine if split is valid.
|
# Examine if split is valid.
|
||||||
if split not in ["train", "val", "test"]:
|
if split not in ["train", "val", "test"]:
|
||||||
raise ValueError("split has to be one of (train, val, test).")
|
raise ValueError("split has to be one of (train, val, test).")
|
||||||
@ -48,6 +62,16 @@ class R2N2(ShapeNetBase):
|
|||||||
with open(splits_file) as splits:
|
with open(splits_file) as splits:
|
||||||
split_dict = json.load(splits)[split]
|
split_dict = json.load(splits)[split]
|
||||||
|
|
||||||
|
self.return_images = True
|
||||||
|
# Check if the folder containing R2N2 renderings is included in r2n2_dir.
|
||||||
|
if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")):
|
||||||
|
self.return_images = False
|
||||||
|
msg = (
|
||||||
|
"ShapeNetRendering not found in %s. R2N2 renderings will "
|
||||||
|
"be skipped when returning models."
|
||||||
|
) % (r2n2_dir)
|
||||||
|
warnings.warn(msg)
|
||||||
|
|
||||||
synset_set = set()
|
synset_set = set()
|
||||||
for synset in split_dict.keys():
|
for synset in split_dict.keys():
|
||||||
# Examine if the given synset is present in the ShapeNetCore dataset
|
# Examine if the given synset is present in the ShapeNetCore dataset
|
||||||
@ -95,12 +119,15 @@ class R2N2(ShapeNetBase):
|
|||||||
) % (shapenet_dir, ", ".join(synset_not_present))
|
) % (shapenet_dir, ", ".join(synset_not_present))
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Dict:
|
def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
|
||||||
"""
|
"""
|
||||||
Read a model by the given index.
|
Read a model by the given index.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
idx: The idx of the model to be retrieved in the dataset.
|
model_idx: The idx of the model to be retrieved in the dataset.
|
||||||
|
view_idx: List of indices of the view to be returned. Each index needs to be
|
||||||
|
between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be
|
||||||
|
ignored and views will be sampled according to self.return_all_views.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dictionary with following keys:
|
dictionary with following keys:
|
||||||
@ -109,12 +136,51 @@ class R2N2(ShapeNetBase):
|
|||||||
- synset_id (str): synset id.
|
- synset_id (str): synset id.
|
||||||
- model_id (str): model id.
|
- model_id (str): model id.
|
||||||
- label (str): synset label.
|
- label (str): synset label.
|
||||||
|
- images: FloatTensor of shape (V, H, W, C), where V is number of views
|
||||||
|
returned. Returns a batch of the renderings of the models from the R2N2 dataset.
|
||||||
"""
|
"""
|
||||||
model = self._get_item_ids(idx)
|
if type(model_idx) is tuple:
|
||||||
|
model_idx, view_idxs = model_idx
|
||||||
|
model = self._get_item_ids(model_idx)
|
||||||
model_path = path.join(
|
model_path = path.join(
|
||||||
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
|
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
|
||||||
)
|
)
|
||||||
model["verts"], faces, _ = load_obj(model_path)
|
model["verts"], faces, _ = load_obj(model_path)
|
||||||
model["faces"] = faces.verts_idx
|
model["faces"] = faces.verts_idx
|
||||||
model["label"] = self.synset_dict[model["synset_id"]]
|
model["label"] = self.synset_dict[model["synset_id"]]
|
||||||
|
|
||||||
|
model["images"] = None
|
||||||
|
# Retrieve R2N2's renderings if required.
|
||||||
|
if self.return_images:
|
||||||
|
ranges = (
|
||||||
|
range(24) if self.return_all_views else torch.randint(24, (1,)).tolist()
|
||||||
|
)
|
||||||
|
if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs):
|
||||||
|
msg = (
|
||||||
|
"One of the indicies in view_idxs is out of range. "
|
||||||
|
"Index needs to be between 0 and 23, inclusive. "
|
||||||
|
"Now sampling according to self.return_all_views."
|
||||||
|
)
|
||||||
|
warnings.warn(msg)
|
||||||
|
elif view_idxs is not None:
|
||||||
|
ranges = view_idxs
|
||||||
|
|
||||||
|
rendering_path = path.join(
|
||||||
|
self.r2n2_dir,
|
||||||
|
"ShapeNetRendering",
|
||||||
|
model["synset_id"],
|
||||||
|
model["model_id"],
|
||||||
|
"rendering",
|
||||||
|
)
|
||||||
|
|
||||||
|
images = []
|
||||||
|
for i in ranges:
|
||||||
|
# Read image.
|
||||||
|
image_path = path.join(rendering_path, "%02d.png" % i)
|
||||||
|
raw_img = Image.open(image_path)
|
||||||
|
image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
|
||||||
|
images.append(image.to(dtype=torch.float32))
|
||||||
|
|
||||||
|
model["images"] = torch.stack(images)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
from pytorch3d.structures import Meshes
|
from pytorch3d.structures import Meshes
|
||||||
|
|
||||||
|
|
||||||
@ -32,4 +33,10 @@ def collate_batched_meshes(batch: List[Dict]):
|
|||||||
verts=collated_dict["verts"], faces=collated_dict["faces"]
|
verts=collated_dict["verts"], faces=collated_dict["faces"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If collate_batched_meshes receives R2N2 items, stack the batches of
|
||||||
|
# views of each model into a new batch of shape (N, V, H, W, 3) where
|
||||||
|
# V is the number of views.
|
||||||
|
if "images" in collated_dict:
|
||||||
|
collated_dict["images"] = torch.stack(collated_dict["images"])
|
||||||
|
|
||||||
return collated_dict
|
return collated_dict
|
||||||
|
@ -56,7 +56,8 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_load_R2N2(self):
|
def test_load_R2N2(self):
|
||||||
"""
|
"""
|
||||||
Test the loaded train split of R2N2 return items of the correct shapes and types.
|
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
|
||||||
|
check the first image returned is correct.
|
||||||
"""
|
"""
|
||||||
# Load dataset in the train split.
|
# Load dataset in the train split.
|
||||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||||
@ -68,8 +69,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
||||||
|
|
||||||
# Randomly retrieve an object from the dataset.
|
# Randomly retrieve an object from the dataset.
|
||||||
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
|
rand_idx = torch.randint(len(r2n2_dataset), (1,))
|
||||||
# Check that data type and shape of the item returned by __getitem__ are correct.
|
rand_obj = r2n2_dataset[rand_idx]
|
||||||
|
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
|
||||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||||
self.assertTrue(verts.dtype == torch.float32)
|
self.assertTrue(verts.dtype == torch.float32)
|
||||||
self.assertTrue(faces.dtype == torch.int64)
|
self.assertTrue(faces.dtype == torch.int64)
|
||||||
@ -78,6 +80,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(faces.ndim, 2)
|
self.assertEqual(faces.ndim, 2)
|
||||||
self.assertEqual(faces.shape[-1], 3)
|
self.assertEqual(faces.shape[-1], 3)
|
||||||
|
|
||||||
|
# Check that image batch returned by __getitem__ has the correct shape.
|
||||||
|
self.assertEqual(rand_obj["images"].shape[0], 24)
|
||||||
|
self.assertEqual(rand_obj["images"].shape[1], 137)
|
||||||
|
self.assertEqual(rand_obj["images"].shape[2], 137)
|
||||||
|
self.assertEqual(rand_obj["images"].shape[-1], 3)
|
||||||
|
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
|
||||||
|
|
||||||
def test_collate_models(self):
|
def test_collate_models(self):
|
||||||
"""
|
"""
|
||||||
Test collate_batched_meshes returns items of the correct shapes and types.
|
Test collate_batched_meshes returns items of the correct shapes and types.
|
||||||
@ -118,6 +127,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(len(object_batch["label"]), batch_size)
|
self.assertEqual(len(object_batch["label"]), batch_size)
|
||||||
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
||||||
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
||||||
|
self.assertEqual(object_batch["images"].shape[0], batch_size)
|
||||||
|
|
||||||
def test_catch_render_arg_errors(self):
|
def test_catch_render_arg_errors(self):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user