mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Return R2N2 renderings
Summary: R2N2 returns R2N2's own renderings of ShapeNetCore models. Reviewed By: nikhilaravi Differential Revision: D22266988 fbshipit-source-id: 36e67bd06c6459773e6e5f654259166b579be36a
This commit is contained in:
parent
5636eb6152
commit
dc08c30583
@ -4,8 +4,11 @@ import json
|
||||
import warnings
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
||||
from pytorch3d.io import load_obj
|
||||
|
||||
@ -21,18 +24,29 @@ class R2N2(ShapeNetBase):
|
||||
voxelized models.
|
||||
"""
|
||||
|
||||
def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
|
||||
def __init__(
|
||||
self,
|
||||
split: str,
|
||||
shapenet_dir,
|
||||
r2n2_dir,
|
||||
splits_file,
|
||||
return_all_views: bool = True,
|
||||
):
|
||||
"""
|
||||
Store each object's synset id and models id the given directories.
|
||||
|
||||
Args:
|
||||
split (str): One of (train, val, test).
|
||||
shapenet_dir (path): Path to ShapeNet core v1.
|
||||
r2n2_dir (path): Path to the R2N2 dataset.
|
||||
splits_file (path): File containing the train/val/test splits.
|
||||
return_all_views (bool): Indicator of whether or not to return all 24 views. If set
|
||||
to False, one of the 24 views would be randomly selected and returned.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapenet_dir = shapenet_dir
|
||||
self.r2n2_dir = r2n2_dir
|
||||
self.return_all_views = return_all_views
|
||||
# Examine if split is valid.
|
||||
if split not in ["train", "val", "test"]:
|
||||
raise ValueError("split has to be one of (train, val, test).")
|
||||
@ -48,6 +62,16 @@ class R2N2(ShapeNetBase):
|
||||
with open(splits_file) as splits:
|
||||
split_dict = json.load(splits)[split]
|
||||
|
||||
self.return_images = True
|
||||
# Check if the folder containing R2N2 renderings is included in r2n2_dir.
|
||||
if not path.isdir(path.join(r2n2_dir, "ShapeNetRendering")):
|
||||
self.return_images = False
|
||||
msg = (
|
||||
"ShapeNetRendering not found in %s. R2N2 renderings will "
|
||||
"be skipped when returning models."
|
||||
) % (r2n2_dir)
|
||||
warnings.warn(msg)
|
||||
|
||||
synset_set = set()
|
||||
for synset in split_dict.keys():
|
||||
# Examine if the given synset is present in the ShapeNetCore dataset
|
||||
@ -95,12 +119,15 @@ class R2N2(ShapeNetBase):
|
||||
) % (shapenet_dir, ", ".join(synset_not_present))
|
||||
warnings.warn(msg)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
def __getitem__(self, model_idx, view_idxs: Optional[List[int]] = None) -> Dict:
|
||||
"""
|
||||
Read a model by the given index.
|
||||
|
||||
Args:
|
||||
idx: The idx of the model to be retrieved in the dataset.
|
||||
model_idx: The idx of the model to be retrieved in the dataset.
|
||||
view_idx: List of indices of the view to be returned. Each index needs to be
|
||||
between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be
|
||||
ignored and views will be sampled according to self.return_all_views.
|
||||
|
||||
Returns:
|
||||
dictionary with following keys:
|
||||
@ -109,12 +136,51 @@ class R2N2(ShapeNetBase):
|
||||
- synset_id (str): synset id.
|
||||
- model_id (str): model id.
|
||||
- label (str): synset label.
|
||||
- images: FloatTensor of shape (V, H, W, C), where V is number of views
|
||||
returned. Returns a batch of the renderings of the models from the R2N2 dataset.
|
||||
"""
|
||||
model = self._get_item_ids(idx)
|
||||
if type(model_idx) is tuple:
|
||||
model_idx, view_idxs = model_idx
|
||||
model = self._get_item_ids(model_idx)
|
||||
model_path = path.join(
|
||||
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
|
||||
)
|
||||
model["verts"], faces, _ = load_obj(model_path)
|
||||
model["faces"] = faces.verts_idx
|
||||
model["label"] = self.synset_dict[model["synset_id"]]
|
||||
|
||||
model["images"] = None
|
||||
# Retrieve R2N2's renderings if required.
|
||||
if self.return_images:
|
||||
ranges = (
|
||||
range(24) if self.return_all_views else torch.randint(24, (1,)).tolist()
|
||||
)
|
||||
if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs):
|
||||
msg = (
|
||||
"One of the indicies in view_idxs is out of range. "
|
||||
"Index needs to be between 0 and 23, inclusive. "
|
||||
"Now sampling according to self.return_all_views."
|
||||
)
|
||||
warnings.warn(msg)
|
||||
elif view_idxs is not None:
|
||||
ranges = view_idxs
|
||||
|
||||
rendering_path = path.join(
|
||||
self.r2n2_dir,
|
||||
"ShapeNetRendering",
|
||||
model["synset_id"],
|
||||
model["model_id"],
|
||||
"rendering",
|
||||
)
|
||||
|
||||
images = []
|
||||
for i in ranges:
|
||||
# Read image.
|
||||
image_path = path.join(rendering_path, "%02d.png" % i)
|
||||
raw_img = Image.open(image_path)
|
||||
image = torch.from_numpy(np.array(raw_img) / 255.0)[..., :3]
|
||||
images.append(image.to(dtype=torch.float32))
|
||||
|
||||
model["images"] = torch.stack(images)
|
||||
|
||||
return model
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
|
||||
@ -32,4 +33,10 @@ def collate_batched_meshes(batch: List[Dict]):
|
||||
verts=collated_dict["verts"], faces=collated_dict["faces"]
|
||||
)
|
||||
|
||||
# If collate_batched_meshes receives R2N2 items, stack the batches of
|
||||
# views of each model into a new batch of shape (N, V, H, W, 3) where
|
||||
# V is the number of views.
|
||||
if "images" in collated_dict:
|
||||
collated_dict["images"] = torch.stack(collated_dict["images"])
|
||||
|
||||
return collated_dict
|
||||
|
@ -56,7 +56,8 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def test_load_R2N2(self):
|
||||
"""
|
||||
Test the loaded train split of R2N2 return items of the correct shapes and types.
|
||||
Test the loaded train split of R2N2 return items of the correct shapes and types. Also
|
||||
check the first image returned is correct.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
@ -68,8 +69,9 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
||||
|
||||
# Randomly retrieve an object from the dataset.
|
||||
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
|
||||
# Check that data type and shape of the item returned by __getitem__ are correct.
|
||||
rand_idx = torch.randint(len(r2n2_dataset), (1,))
|
||||
rand_obj = r2n2_dataset[rand_idx]
|
||||
# Check that verts and faces returned by __getitem__ have the correct shapes and types.
|
||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||
self.assertTrue(verts.dtype == torch.float32)
|
||||
self.assertTrue(faces.dtype == torch.int64)
|
||||
@ -78,6 +80,13 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(faces.ndim, 2)
|
||||
self.assertEqual(faces.shape[-1], 3)
|
||||
|
||||
# Check that image batch returned by __getitem__ has the correct shape.
|
||||
self.assertEqual(rand_obj["images"].shape[0], 24)
|
||||
self.assertEqual(rand_obj["images"].shape[1], 137)
|
||||
self.assertEqual(rand_obj["images"].shape[2], 137)
|
||||
self.assertEqual(rand_obj["images"].shape[-1], 3)
|
||||
self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
|
||||
|
||||
def test_collate_models(self):
|
||||
"""
|
||||
Test collate_batched_meshes returns items of the correct shapes and types.
|
||||
@ -118,6 +127,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(len(object_batch["label"]), batch_size)
|
||||
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
||||
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
||||
self.assertEqual(object_batch["images"].shape[0], batch_size)
|
||||
|
||||
def test_catch_render_arg_errors(self):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user