Skeleton of ShapeNetCore class

Summary: Skeleton of ShapeNetCore class that loads ShapeNet v1 from a given directory to a Dataset object. Overrides _init_, _len_, and _getitem_ from torch.utils.data.Dataset. Currently getitem returns verts, faces and id_str, where id_str is a concatenation of synset_id and obj_id. Planning on adding support for loading ShapeNet v2, retrieving textures and returning wordnet synsets (not just ids) in next diffs.

Reviewed By: nikhilaravi

Differential Revision: D21986222

fbshipit-source-id: c2c515303f1898b6c495b52cb53c74d691585326
This commit is contained in:
Luya Gao
2020-06-17 20:29:23 -07:00
committed by Facebook GitHub Bot
parent 2f6387f239
commit 9d279ba543
2 changed files with 136 additions and 0 deletions

View File

@@ -0,0 +1,74 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import os
import warnings
from os import path
import torch
from pytorch3d.io import load_obj
class ShapeNetCore(torch.utils.data.Dataset):
"""
This class loads ShapeNet v.1 from a given directory into a Dataset object.
"""
def __init__(self, data_dir):
"""
Stores each object's synset id and models id from data_dir.
Args:
data_dir (path): Path to shapenet data
"""
self.data_dir = data_dir
# List of subdirectories of data_dir each containing a category of models.
# The name of each subdirectory is the wordnet synset offset of that category.
wnsynset_list = [
wnsynset
for wnsynset in os.listdir(data_dir)
if path.isdir(path.join(data_dir, wnsynset))
]
# Extract synset_id and model_id of each object from directory names.
# Each grandchildren directory of data_dir contains an object, and the name
# of the directory is the object's model_id.
self.synset_ids = []
self.model_ids = []
for synset in wnsynset_list:
for model in os.listdir(path.join(data_dir, synset)):
if not path.exists(path.join(data_dir, synset, model, "model.obj")):
msg = """ model.obj not found in the model directory %s
under synset directory %s.""" % (
model,
synset,
)
warnings.warn(msg)
else:
self.synset_ids.append(synset)
self.model_ids.append(model)
def __len__(self):
"""
Returns # of total models in shapenet core
"""
return len(self.model_ids)
def __getitem__(self, idx):
"""
Read a model by the given index.
Returns:
dictionary with following keys:
- verts: FloatTensor of shape (V, 3).
- faces: LongTensor of shape (F, 3) which indexes into the verts tensor.
- synset_id (str): synset id
- model_id (str): model id
"""
model = {}
model["synset_id"] = self.synset_ids[idx]
model["model_id"] = self.model_ids[idx]
model_path = path.join(
self.data_dir, model["synset_id"], model["model_id"], "model.obj"
)
model["verts"], faces, _ = load_obj(model_path)
model["faces"] = faces.verts_idx
return model