Test R2N2 loads correct numbers of instances
Summary:
Sample/get all views at the loading phase instead of at the returning phase; load only the views from the split instead of all 24 views; test that the numbers of views loaded are correct for each category.

Reviewed By: nikhilaravi

Differential Revision: D22631414

fbshipit-source-id: 1c5ce99fe2bdf6618c1aa0b69bb6899473376bc2
parent 7cb9d8ea86
commit 483e538dae
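For orientation, a minimal usage sketch of the behavior this commit sets up (not part of the diff). The constructor arguments, indexing forms, and image shapes are taken from the test code below; the import path and the data paths are assumptions.

# Sketch only: paths are placeholders, and `from pytorch3d.datasets import R2N2`
# is assumed to be the public import for the R2N2 dataset class.
from pytorch3d.datasets import R2N2

SHAPENET_PATH = "/path/to/ShapeNetCore.v1"
R2N2_PATH = "/path/to/r2n2"
SPLITS_PATH = "/path/to/r2n2_splits.json"

# Views are now gathered per model at load time, restricted to the chosen split.
dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH, return_all_views=True)

obj = dataset[39]                 # all views loaded for model 39
print(obj["images"].shape)        # (V, 137, 137, 3) with V <= 24

one_view = dataset[39, [21]]      # request view 21 explicitly
print(one_view["images"].shape)   # (1, 137, 137, 3) if view 21 is in the split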
@@ -11,6 +11,7 @@ import torch
from PIL import Image
from pytorch3d.datasets.shapenet_base import ShapeNetBase
from pytorch3d.io import load_obj
from tabulate import tabulate


SYNSET_DICT_DIR = Path(__file__).resolve().parent

@@ -21,7 +22,8 @@ class R2N2(ShapeNetBase):
    This class loads the R2N2 dataset from a given directory into a Dataset object.
    The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1
    dataset. The R2N2 dataset also contains its own 24 renderings of each object and
    voxelized models.
    voxelized models. Most of the models have all 24 views in the same split, but there
    are eight of them that divide their views between train and test splits.
    """

    def __init__(

@@ -40,13 +42,13 @@ class R2N2(ShapeNetBase):
            shapenet_dir (path): Path to ShapeNet core v1.
            r2n2_dir (path): Path to the R2N2 dataset.
            splits_file (path): File containing the train/val/test splits.
            return_all_views (bool): Indicator of whether or not to return all 24 views. If set
                to False, one of the 24 views would be randomly selected and returned.
            return_all_views (bool): Indicator of whether or not to load all the views in
                the split. If set to False, one of the views in the split will be randomly
                selected and loaded.
        """
        super().__init__()
        self.shapenet_dir = shapenet_dir
        self.r2n2_dir = r2n2_dir
        self.return_all_views = return_all_views
        # Examine if split is valid.
        if split not in ["train", "val", "test"]:
            raise ValueError("split has to be one of (train, val, test).")

@@ -73,6 +75,10 @@ class R2N2(ShapeNetBase):
            warnings.warn(msg)

        synset_set = set()
        # Store lists of views of each model in a list.
        self.views_per_model_list = []
        # Store tuples of synset label and total number of views in each category in a list.
        synset_num_instances = []
        for synset in split_dict.keys():
            # Examine if the given synset is present in the ShapeNetCore dataset
            # and is also part of the standard R2N2 dataset.

@@ -88,9 +94,10 @@ class R2N2(ShapeNetBase):
                continue

            synset_set.add(synset)
            self.synset_starts[synset] = len(self.synset_ids)
            models = split_dict[synset].keys()
            for model in models:
            self.synset_start_idxs[synset] = len(self.synset_ids)
            # Start counting total number of views in the current category.
            synset_view_count = 0
            for model in split_dict[synset]:
                # Examine if the given model is present in the ShapeNetCore path.
                shapenet_path = path.join(shapenet_dir, synset, model)
                if not path.isdir(shapenet_path):

@@ -103,13 +110,28 @@ class R2N2(ShapeNetBase):
                    continue
                self.synset_ids.append(synset)
                self.model_ids.append(model)
            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]

                model_views = split_dict[synset][model]
                # Randomly select a view index if return_all_views set to False.
                if not return_all_views:
                    rand_idx = torch.randint(len(model_views), (1,))
                    model_views = [model_views[rand_idx]]
                self.views_per_model_list.append(model_views)
                synset_view_count += len(model_views)
            synset_num_instances.append((self.synset_dict[synset], synset_view_count))
            model_count = len(self.synset_ids) - self.synset_start_idxs[synset]
            self.synset_num_models[synset] = model_count
        headers = ["category", "#instances"]
        synset_num_instances.append(("total", sum(n for _, n in synset_num_instances)))
        print(
            tabulate(synset_num_instances, headers, numalign="left", stralign="center")
        )

        # Examine if all the synsets in the standard R2N2 mapping are present.
        # Update self.synset_inv so that it only includes the loaded categories.
        synset_not_present = [
            self.synset_inv.pop(self.synset_dict[synset])
            for synset in self.synset_dict.keys()
            for synset in self.synset_dict
            if synset not in synset_set
        ]
        if len(synset_not_present) > 0:

@@ -126,8 +148,9 @@ class R2N2(ShapeNetBase):
        Args:
            model_idx: The idx of the model to be retrieved in the dataset.
            view_idx: List of indices of the view to be returned. Each index needs to be
                between 0 and 23, inclusive. If an invalid index is supplied, view_idx will be
                ignored and views will be sampled according to self.return_all_views.
                contained in the loaded split (always between 0 and 23, inclusive). If
                an invalid index is supplied, view_idx will be ignored and all the loaded
                views will be returned.

        Returns:
            dictionary with following keys:

@@ -139,8 +162,31 @@ class R2N2(ShapeNetBase):
            - images: FloatTensor of shape (V, H, W, C), where V is number of views
                returned. Returns a batch of the renderings of the models from the R2N2 dataset.
        """
        if type(model_idx) is tuple:
        if isinstance(model_idx, tuple):
            model_idx, view_idxs = model_idx
        if view_idxs is not None:
            if isinstance(view_idxs, int):
                view_idxs = [view_idxs]
            if not isinstance(view_idxs, list) and not torch.is_tensor(view_idxs):
                raise TypeError(
                    "view_idxs is of type %s but it needs to be a list."
                    % type(view_idxs)
                )

        model_views = self.views_per_model_list[model_idx]
        if view_idxs is not None and any(
            idx not in self.views_per_model_list[model_idx] for idx in view_idxs
        ):
            msg = """At least one of the indices in view_idxs is not available.
                Specified view of the model needs to be contained in the
                loaded split. If return_all_views is set to False, only one
                random view is loaded. Try accessing the specified view(s)
                after loading the dataset with self.return_all_views set to True.
                Now returning all view(s) in the loaded dataset."""
            warnings.warn(msg)
        elif view_idxs is not None:
            model_views = view_idxs

        model = self._get_item_ids(model_idx)
        model_path = path.join(
            self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"

@@ -152,19 +198,6 @@ class R2N2(ShapeNetBase):
        model["images"] = None
        # Retrieve R2N2's renderings if required.
        if self.return_images:
            ranges = (
                range(24) if self.return_all_views else torch.randint(24, (1,)).tolist()
            )
            if view_idxs is not None and any(idx < 0 or idx > 23 for idx in view_idxs):
                msg = (
                    "One of the indicies in view_idxs is out of range. "
                    "Index needs to be between 0 and 23, inclusive. "
                    "Now sampling according to self.return_all_views."
                )
                warnings.warn(msg)
            elif view_idxs is not None:
                ranges = view_idxs

            rendering_path = path.join(
                self.r2n2_dir,
                "ShapeNetRendering",

@@ -174,7 +207,7 @@ class R2N2(ShapeNetBase):
            )

            images = []
            for i in ranges:
            for i in model_views:
                # Read image.
                image_path = path.join(rendering_path, "%02d.png" % i)
                raw_img = Image.open(image_path)

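The counting added in __init__ above ends with a tabulate summary of how many views were loaded per category. A self-contained sketch of that table follows; the category names and counts are made up, only the tabulate call mirrors the diff.

from tabulate import tabulate

# Illustrative rows; in the dataset these come from the loaded splits file.
synset_num_instances = [("chair", 120), ("table", 96)]
headers = ["category", "#instances"]
synset_num_instances.append(("total", sum(n for _, n in synset_num_instances)))
print(tabulate(synset_num_instances, headers, numalign="left", stralign="center"))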
@@ -100,7 +100,7 @@ class ShapeNetCore(ShapeNetBase):
        # Each grandchildren directory of data_dir contains an object, and the name
        # of the directory is the object's model_id.
        for synset in synset_set:
            self.synset_starts[synset] = len(self.synset_ids)
            self.synset_start_idxs[synset] = len(self.synset_ids)
            for model in os.listdir(path.join(data_dir, synset)):
                if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
                    msg = (

@@ -111,7 +111,8 @@ class ShapeNetCore(ShapeNetBase):
                    continue
                self.synset_ids.append(synset)
                self.model_ids.append(model)
            self.synset_lens[synset] = len(self.synset_ids) - self.synset_starts[synset]
            model_count = len(self.synset_ids) - self.synset_start_idxs[synset]
            self.synset_num_models[synset] = model_count

    def __getitem__(self, idx: int) -> Dict:
        """

@@ -31,8 +31,8 @@ class ShapeNetBase(torch.utils.data.Dataset):
        self.synset_ids = []
        self.model_ids = []
        self.synset_inv = {}
        self.synset_starts = {}
        self.synset_lens = {}
        self.synset_start_idxs = {}
        self.synset_num_models = {}
        self.shapenet_dir = ""
        self.model_dir = "model.obj"

@@ -227,9 +227,9 @@ class ShapeNetBase(torch.utils.data.Dataset):
            category: category synset of the category to be sampled from. If not
                specified, sample from all models in the loaded dataset.
        """
        start = self.synset_starts[category] if category is not None else 0
        start = self.synset_start_idxs[category] if category is not None else 0
        range_len = (
            self.synset_lens[category] if category is not None else self.__len__()
            self.synset_num_models[category] if category is not None else self.__len__()
        )
        replacement = sample_num > range_len
        sampled_idxs = (

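With the rename, sampling a category in ShapeNetBase is driven by synset_start_idxs and synset_num_models. The final sampling expression is cut off in this hunk; below is a rough sketch of the idea, not the library's exact code.

import torch

def sample_category_idxs(start: int, range_len: int, sample_num: int) -> torch.Tensor:
    # Sample with replacement only when more items are requested than exist,
    # mirroring `replacement = sample_num > range_len` above.
    if sample_num > range_len:
        sampled = torch.randint(range_len, (sample_num,))
    else:
        sampled = torch.randperm(range_len)[:sample_num]
    # Offset into the category's slice of the flat model list.
    return sampled + start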
@@ -17,6 +17,7 @@ def collate_batched_meshes(batch: List[Dict]):
    Args:
        batch: List of dictionaries containing information about objects
            in the dataset.

    Returns:
        collated_dict: Dictionary of collated lists. If batch contains both
            verts and faces, a collated mesh batch is also returned.

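collate_batched_meshes is meant to be passed to a DataLoader as its collate_fn. A short sketch, assuming the `dataset` built in the earlier example and that the helper is importable from pytorch3d.datasets:

from torch.utils.data import DataLoader

from pytorch3d.datasets import collate_batched_meshes

loader = DataLoader(dataset, batch_size=4, collate_fn=collate_batched_meshes)
batch = next(iter(loader))
# batch is a dict of collated lists keyed like the per-item dicts
# ("synset_id", "model_id", "verts", "faces", ...); because verts and faces are
# present, a collated mesh batch is also included, per the docstring above.
print(batch["synset_id"])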
@@ -59,20 +59,30 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
        Test the loaded train split of R2N2 return items of the correct shapes and types. Also
        check the first image returned is correct.
        """
        # Load dataset in the train split.
        r2n2_dataset = R2N2("train", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
        # Load dataset in the test split.
        r2n2_dataset = R2N2("test", SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)

        # Check total number of objects in the dataset is correct.
        with open(SPLITS_PATH) as splits:
            split_dict = json.load(splits)["train"]
        model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
            split_dict = json.load(splits)["test"]
        model_nums = [len(split_dict[synset]) for synset in split_dict]
        self.assertEqual(len(r2n2_dataset), sum(model_nums))

        # Randomly retrieve an object from the dataset.
        rand_idx = torch.randint(len(r2n2_dataset), (1,))
        rand_obj = r2n2_dataset[rand_idx]
        # Check the numbers of loaded instances for each category are correct.
        for synset in split_dict:
            split_synset_nums = sum(
                len(split_dict[synset][model]) for model in split_dict[synset]
            )
            idx_start = r2n2_dataset.synset_start_idxs[synset]
            idx_end = idx_start + r2n2_dataset.synset_num_models[synset]
            synset_views_list = r2n2_dataset.views_per_model_list[idx_start:idx_end]
            loaded_synset_views = sum(len(views) for views in synset_views_list)
            self.assertEqual(loaded_synset_views, split_synset_nums)

        # Retrieve an object from the dataset.
        r2n2_obj = r2n2_dataset[39]
        # Check that verts and faces returned by __getitem__ have the correct shapes and types.
        verts, faces = rand_obj["verts"], rand_obj["faces"]
        verts, faces = r2n2_obj["verts"], r2n2_obj["faces"]
        self.assertTrue(verts.dtype == torch.float32)
        self.assertTrue(faces.dtype == torch.int64)
        self.assertEqual(verts.ndim, 2)

@@ -81,11 +91,17 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
        self.assertEqual(faces.shape[-1], 3)

        # Check that image batch returned by __getitem__ has the correct shape.
        self.assertEqual(rand_obj["images"].shape[0], 24)
        self.assertEqual(rand_obj["images"].shape[1], 137)
        self.assertEqual(rand_obj["images"].shape[2], 137)
        self.assertEqual(rand_obj["images"].shape[-1], 3)
        self.assertEqual(r2n2_dataset[rand_idx, [21]]["images"].shape[0], 1)
        self.assertEqual(r2n2_obj["images"].shape[0], 24)
        self.assertEqual(r2n2_obj["images"].shape[1], 137)
        self.assertEqual(r2n2_obj["images"].shape[2], 137)
        self.assertEqual(r2n2_obj["images"].shape[-1], 3)
        self.assertEqual(r2n2_dataset[39, [21]]["images"].shape[0], 1)
        self.assertEqual(r2n2_dataset[39, torch.tensor([12, 21])]["images"].shape[0], 2)

        # Check models with total view counts less than 24 return image batches
        # of the correct shapes.
        self.assertEqual(r2n2_dataset[635]["images"].shape[0], 5)
        self.assertEqual(r2n2_dataset[8369]["images"].shape[0], 10)

    def test_collate_models(self):
        """
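The tests above also pin down what happens when a requested view is not in the loaded split: view_idxs is ignored, a warning is emitted, and all loaded views come back. A sketch, reusing `dataset` from the first example; the index 99 is just an arbitrary invalid view.

import warnings

import torch

# Two valid views come back as a (2, 137, 137, 3) batch.
assert dataset[39, torch.tensor([12, 21])]["images"].shape[0] == 2

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # View 99 cannot be in the split, so all loaded views of model 39 are
    # returned and a warning is raised.
    images = dataset[39, [99]]["images"]
    assert len(caught) == 1 and images.shape[0] >= 1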