mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
R2N2 skeleton
Summary: Skeleton of R2N2 that for now only returns verts and faces extracted from ShapeNetCore v1. Reviewed By: nikhilaravi Differential Revision: D22203656 fbshipit-source-id: 00db6ac76bfdb76fdbc77a2087c34a3f0ff01e6a
This commit is contained in:
parent
22d8c3337a
commit
49b4ce1acc
@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .r2n2 import R2N2
|
||||
from .shapenet import ShapeNetCore
|
||||
from .utils import collate_batched_meshes
|
||||
|
||||
|
6
pytorch3d/datasets/r2n2/__init__.py
Normal file
6
pytorch3d/datasets/r2n2/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .r2n2 import R2N2
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
118
pytorch3d/datasets/r2n2/r2n2.py
Normal file
118
pytorch3d/datasets/r2n2/r2n2.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import json
|
||||
import warnings
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from pytorch3d.datasets.shapenet_base import ShapeNetBase
|
||||
from pytorch3d.io import load_obj
|
||||
|
||||
|
||||
SYNSET_DICT_DIR = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
class R2N2(ShapeNetBase):
|
||||
"""
|
||||
This class loads the R2N2 dataset from a given directory into a Dataset object.
|
||||
The R2N2 dataset contains 13 categories that are a subset of the ShapeNetCore v.1
|
||||
dataset. The R2N2 dataset also contains its own 24 renderings of each object and
|
||||
voxelized models.
|
||||
"""
|
||||
|
||||
def __init__(self, split, shapenet_dir, r2n2_dir, splits_file):
|
||||
"""
|
||||
Store each object's synset id and models id the given directories.
|
||||
Args:
|
||||
split (str): One of (train, val, test).
|
||||
shapenet_dir (path): Path to ShapeNet core v1.
|
||||
r2n2_dir (path): Path to the R2N2 dataset.
|
||||
splits_file (path): File containing the train/val/test splits.
|
||||
"""
|
||||
super().__init__()
|
||||
self.shapenet_dir = shapenet_dir
|
||||
self.r2n2_dir = r2n2_dir
|
||||
# Examine if split is valid.
|
||||
if split not in ["train", "val", "test"]:
|
||||
raise ValueError("split has to be one of (train, val, test).")
|
||||
# Synset dictionary mapping synset offsets in R2N2 to corresponding labels.
|
||||
with open(
|
||||
path.join(SYNSET_DICT_DIR, "r2n2_synset_dict.json"), "r"
|
||||
) as read_dict:
|
||||
self.synset_dict = json.load(read_dict)
|
||||
# Inverse dicitonary mapping synset labels to corresponding offsets.
|
||||
self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}
|
||||
|
||||
# Store synset and model ids of objects mentioned in the splits_file.
|
||||
with open(splits_file) as splits:
|
||||
split_dict = json.load(splits)[split]
|
||||
|
||||
synset_set = set()
|
||||
for synset in split_dict.keys():
|
||||
# Examine if the given synset is present in the ShapeNetCore dataset
|
||||
# and is also part of the standard R2N2 dataset.
|
||||
if not (
|
||||
path.isdir(path.join(shapenet_dir, synset))
|
||||
and synset in self.synset_dict
|
||||
):
|
||||
msg = (
|
||||
"Synset category %s from the splits file is either not "
|
||||
"present in %s or not part of the standard R2N2 dataset."
|
||||
) % (synset, shapenet_dir)
|
||||
warnings.warn(msg)
|
||||
continue
|
||||
|
||||
synset_set.add(synset)
|
||||
models = split_dict[synset].keys()
|
||||
for model in models:
|
||||
# Examine if the given model is present in the ShapeNetCore path.
|
||||
shapenet_path = path.join(shapenet_dir, synset, model)
|
||||
if not path.isdir(shapenet_path):
|
||||
msg = "Model %s from category %s is not present in %s." % (
|
||||
model,
|
||||
synset,
|
||||
shapenet_dir,
|
||||
)
|
||||
warnings.warn(msg)
|
||||
continue
|
||||
self.synset_ids.append(synset)
|
||||
self.model_ids.append(model)
|
||||
|
||||
# Examine if all the synsets in the standard R2N2 mapping are present.
|
||||
# Update self.synset_inv so that it only includes the loaded categories.
|
||||
synset_not_present = [
|
||||
self.synset_inv.pop(self.synset_dict[synset])
|
||||
for synset in self.synset_dict.keys()
|
||||
if synset not in synset_set
|
||||
]
|
||||
if len(synset_not_present) > 0:
|
||||
msg = (
|
||||
"The following categories are included in R2N2's"
|
||||
"official mapping but not found in the dataset location %s: %s"
|
||||
) % (shapenet_dir, ", ".join(synset_not_present))
|
||||
warnings.warn(msg)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict:
|
||||
"""
|
||||
Read a model by the given index.
|
||||
|
||||
Args:
|
||||
idx: The idx of the model to be retrieved in the dataset.
|
||||
|
||||
Returns:
|
||||
dictionary with following keys:
|
||||
- verts: FloatTensor of shape (V, 3).
|
||||
- faces: faces.verts_idx, LongTensor of shape (F, 3).
|
||||
- synset_id (str): synset id.
|
||||
- model_id (str): model id.
|
||||
- label (str): synset label.
|
||||
"""
|
||||
model = self._get_item_ids(idx)
|
||||
model_path = path.join(
|
||||
self.shapenet_dir, model["synset_id"], model["model_id"], "model.obj"
|
||||
)
|
||||
model["verts"], faces, _ = load_obj(model_path)
|
||||
model["faces"] = faces.verts_idx
|
||||
model["label"] = self.synset_dict[model["synset_id"]]
|
||||
return model
|
15
pytorch3d/datasets/r2n2/r2n2_synset_dict.json
Normal file
15
pytorch3d/datasets/r2n2/r2n2_synset_dict.json
Normal file
@ -0,0 +1,15 @@
|
||||
{
|
||||
"04256520": "sofa",
|
||||
"02933112": "cabinet",
|
||||
"02828884": "bench",
|
||||
"03001627": "chair",
|
||||
"03211117": "display",
|
||||
"04090263": "rifle",
|
||||
"03691459": "loudspeaker",
|
||||
"03636649": "lamp",
|
||||
"04401088": "telephone",
|
||||
"02691156": "airplane",
|
||||
"04379243": "table",
|
||||
"02958343": "car",
|
||||
"04530566": "watercraft"
|
||||
}
|
111
tests/test_r2n2.py
Normal file
111
tests/test_r2n2.py
Normal file
@ -0,0 +1,111 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
"""
|
||||
Sanity checks for loading R2N2.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.datasets import R2N2, collate_batched_meshes
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
# Set these paths in order to run the tests.
|
||||
R2N2_PATH = None
|
||||
SHAPENET_PATH = None
|
||||
SPLITS_PATH = None
|
||||
|
||||
|
||||
class TestR2N2(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
"""
|
||||
Check if the data paths are given otherwise skip tests.
|
||||
"""
|
||||
if SHAPENET_PATH is None or not os.path.exists(SHAPENET_PATH):
|
||||
url = "https://www.shapenet.org/"
|
||||
msg = (
|
||||
"ShapeNet data not found, download from %s, update "
|
||||
"SHAPENET_PATH at the top of the file, and rerun."
|
||||
)
|
||||
self.skipTest(msg % url)
|
||||
if R2N2_PATH is None or not os.path.exists(R2N2_PATH):
|
||||
url = "http://3d-r2n2.stanford.edu/"
|
||||
msg = (
|
||||
"R2N2 data not found, download from %s, update "
|
||||
"R2N2_PATH at the top of the file, and rerun."
|
||||
)
|
||||
self.skipTest(msg % url)
|
||||
if SPLITS_PATH is None or not os.path.exists(SPLITS_PATH):
|
||||
msg = """Splits file not found, update SPLITS_PATH at the top
|
||||
of the file, and rerun."""
|
||||
self.skipTest(msg)
|
||||
|
||||
def test_load_R2N2(self):
|
||||
"""
|
||||
Test loading the train split of R2N2. Check the loaded dataset return items
|
||||
of the correct shapes and types.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
split = "train"
|
||||
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
|
||||
# Check total number of objects in the dataset is correct.
|
||||
with open(SPLITS_PATH) as splits:
|
||||
split_dict = json.load(splits)[split]
|
||||
model_nums = [len(split_dict[synset].keys()) for synset in split_dict.keys()]
|
||||
self.assertEqual(len(r2n2_dataset), sum(model_nums))
|
||||
|
||||
# Randomly retrieve an object from the dataset.
|
||||
rand_obj = r2n2_dataset[torch.randint(len(r2n2_dataset), (1,))]
|
||||
# Check that data type and shape of the item returned by __getitem__ are correct.
|
||||
verts, faces = rand_obj["verts"], rand_obj["faces"]
|
||||
self.assertTrue(verts.dtype == torch.float32)
|
||||
self.assertTrue(faces.dtype == torch.int64)
|
||||
self.assertEqual(verts.ndim, 2)
|
||||
self.assertEqual(verts.shape[-1], 3)
|
||||
self.assertEqual(faces.ndim, 2)
|
||||
self.assertEqual(faces.shape[-1], 3)
|
||||
|
||||
def test_collate_models(self):
|
||||
"""
|
||||
Test collate_batched_meshes returns items of the correct shapes and types.
|
||||
Check that when collate_batched_meshes is passed to Dataloader, batches of
|
||||
the correct shapes and types are returned.
|
||||
"""
|
||||
# Load dataset in the train split.
|
||||
split = "train"
|
||||
r2n2_dataset = R2N2(split, SHAPENET_PATH, R2N2_PATH, SPLITS_PATH)
|
||||
|
||||
# Randomly retrieve several objects from the dataset and collate them.
|
||||
collated_meshes = collate_batched_meshes(
|
||||
[r2n2_dataset[idx] for idx in torch.randint(len(r2n2_dataset), (6,))]
|
||||
)
|
||||
# Check the collated verts and faces have the correct shapes.
|
||||
verts, faces = collated_meshes["verts"], collated_meshes["faces"]
|
||||
self.assertEqual(len(verts), 6)
|
||||
self.assertEqual(len(faces), 6)
|
||||
self.assertEqual(verts[0].shape[-1], 3)
|
||||
self.assertEqual(faces[0].shape[-1], 3)
|
||||
|
||||
# Check the collated mesh has the correct shape.
|
||||
mesh = collated_meshes["mesh"]
|
||||
self.assertEqual(mesh.verts_padded().shape[0], 6)
|
||||
self.assertEqual(mesh.verts_padded().shape[-1], 3)
|
||||
self.assertEqual(mesh.faces_padded().shape[0], 6)
|
||||
self.assertEqual(mesh.faces_padded().shape[-1], 3)
|
||||
|
||||
# Pass the custom collate_fn function to DataLoader and check elements
|
||||
# in batch have the correct shape.
|
||||
batch_size = 12
|
||||
r2n2_loader = DataLoader(
|
||||
r2n2_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
|
||||
)
|
||||
it = iter(r2n2_loader)
|
||||
object_batch = next(it)
|
||||
self.assertEqual(len(object_batch["synset_id"]), batch_size)
|
||||
self.assertEqual(len(object_batch["model_id"]), batch_size)
|
||||
self.assertEqual(len(object_batch["label"]), batch_size)
|
||||
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
||||
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
Loading…
x
Reference in New Issue
Block a user