From 22d8c3337a2bab3ae4d5cc00de536e3e7668c816 Mon Sep 17 00:00:00 2001 From: Luya Gao Date: Tue, 14 Jul 2020 14:52:21 -0700 Subject: [PATCH] collate_batched_meshes for datasets Summary: Adding collate_batched_meshes for datasets.utils: takes in a list of dictionaries and merge them into one dictionary (while adding a merged mesh to the dictionary). Reviewed By: nikhilaravi Differential Revision: D22180404 fbshipit-source-id: f811f9a140f09638f355ad5739bffa6ee415819f --- pytorch3d/datasets/__init__.py | 1 + pytorch3d/datasets/utils.py | 35 ++++++++++++++++++++++++++++++++++ tests/test_shapenet_core.py | 35 +++++++++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 pytorch3d/datasets/utils.py diff --git a/pytorch3d/datasets/__init__.py b/pytorch3d/datasets/__init__.py index 243247e5..04da090a 100644 --- a/pytorch3d/datasets/__init__.py +++ b/pytorch3d/datasets/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from .shapenet import ShapeNetCore +from .utils import collate_batched_meshes __all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/datasets/utils.py b/pytorch3d/datasets/utils.py new file mode 100644 index 00000000..81616b51 --- /dev/null +++ b/pytorch3d/datasets/utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Dict, List + +from pytorch3d.structures import Meshes + + +def collate_batched_meshes(batch: List[Dict]): + """ + Take a list of objects in the form of dictionaries and merge them + into a single dictionary. This function can be used with a Dataset + object to create a torch.utils.data.Dataloader which directly + returns Meshes objects. + TODO: Add support for textures. + + Args: + batch: List of dictionaries containing information about objects + in the dataset. + Returns: + collated_dict: Dictionary of collated lists. If batch contains both + verts and faces, a collated mesh batch is also returned. + """ + if batch is None or len(batch) == 0: + return None + collated_dict = {} + for k in batch[0].keys(): + collated_dict[k] = [d[k] for d in batch] + + collated_dict["mesh"] = None + if {"verts", "faces"}.issubset(collated_dict.keys()): + collated_dict["mesh"] = Meshes( + verts=collated_dict["verts"], faces=collated_dict["faces"] + ) + + return collated_dict diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py index a1fc664c..4b225674 100644 --- a/tests/test_shapenet_core.py +++ b/tests/test_shapenet_core.py @@ -10,13 +10,14 @@ import numpy as np import torch from common_testing import TestCaseMixin, load_rgb_image from PIL import Image -from pytorch3d.datasets import ShapeNetCore +from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes from pytorch3d.renderer import ( OpenGLPerspectiveCameras, PointLights, RasterizationSettings, look_at_view_transform, ) +from torch.utils.data import DataLoader # Set the SHAPENET_PATH to the local path to the dataset @@ -110,6 +111,38 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase): ] self.assertEqual(len(shapenet_subset), sum(subset_model_nums)) + def test_collate_models(self): + """ + Test collate_batched_meshes returns items of the correct shapes and types. + Check that when collate_batched_meshes is passed to Dataloader, batches of + the correct shapes and types are returned. + """ + # Load ShapeNetCore without specifying any particular categories. + shapenet_dataset = ShapeNetCore(SHAPENET_PATH) + # Randomly retrieve several objects from the dataset. + rand_idxs = torch.randint(len(shapenet_dataset), (6,)) + rand_objs = [shapenet_dataset[idx] for idx in rand_idxs] + + # Collate the randomly selected objects + collated_meshes = collate_batched_meshes(rand_objs) + verts, faces = (collated_meshes["verts"], collated_meshes["faces"]) + self.assertEqual(len(verts), 6) + self.assertEqual(len(faces), 6) + + # Pass the custom collate_fn function to DataLoader and check elements + # in batch have the correct shape. + batch_size = 12 + shapenet_core_loader = DataLoader( + shapenet_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes + ) + it = iter(shapenet_core_loader) + object_batch = next(it) + self.assertEqual(len(object_batch["synset_id"]), batch_size) + self.assertEqual(len(object_batch["model_id"]), batch_size) + self.assertEqual(len(object_batch["label"]), batch_size) + self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size) + self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size) + def test_catch_render_arg_errors(self): """ Test rendering ShapeNetCore with invalid model_ids, categories or indices,