collate_batched_meshes for datasets

Summary: Adding collate_batched_meshes for datasets.utils: takes in a list of dictionaries and merge them into one dictionary (while adding a merged mesh to the dictionary).

Reviewed By: nikhilaravi

Differential Revision: D22180404

fbshipit-source-id: f811f9a140f09638f355ad5739bffa6ee415819f
This commit is contained in:
Luya Gao
2020-07-14 14:52:21 -07:00
committed by Facebook GitHub Bot
parent 22f2963cf1
commit 22d8c3337a
3 changed files with 70 additions and 1 deletions

View File

@@ -10,13 +10,14 @@ import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.datasets import ShapeNetCore
from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes
from pytorch3d.renderer import (
OpenGLPerspectiveCameras,
PointLights,
RasterizationSettings,
look_at_view_transform,
)
from torch.utils.data import DataLoader
# Set the SHAPENET_PATH to the local path to the dataset
@@ -110,6 +111,38 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
]
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
def test_collate_models(self):
"""
Test collate_batched_meshes returns items of the correct shapes and types.
Check that when collate_batched_meshes is passed to Dataloader, batches of
the correct shapes and types are returned.
"""
# Load ShapeNetCore without specifying any particular categories.
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
# Randomly retrieve several objects from the dataset.
rand_idxs = torch.randint(len(shapenet_dataset), (6,))
rand_objs = [shapenet_dataset[idx] for idx in rand_idxs]
# Collate the randomly selected objects
collated_meshes = collate_batched_meshes(rand_objs)
verts, faces = (collated_meshes["verts"], collated_meshes["faces"])
self.assertEqual(len(verts), 6)
self.assertEqual(len(faces), 6)
# Pass the custom collate_fn function to DataLoader and check elements
# in batch have the correct shape.
batch_size = 12
shapenet_core_loader = DataLoader(
shapenet_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
)
it = iter(shapenet_core_loader)
object_batch = next(it)
self.assertEqual(len(object_batch["synset_id"]), batch_size)
self.assertEqual(len(object_batch["model_id"]), batch_size)
self.assertEqual(len(object_batch["label"]), batch_size)
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
def test_catch_render_arg_errors(self):
"""
Test rendering ShapeNetCore with invalid model_ids, categories or indices,