mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
collate_batched_meshes for datasets
Summary: Adding collate_batched_meshes for datasets.utils: takes in a list of dictionaries and merge them into one dictionary (while adding a merged mesh to the dictionary). Reviewed By: nikhilaravi Differential Revision: D22180404 fbshipit-source-id: f811f9a140f09638f355ad5739bffa6ee415819f
This commit is contained in:
parent
22f2963cf1
commit
22d8c3337a
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .shapenet import ShapeNetCore
|
||||
from .utils import collate_batched_meshes
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
35
pytorch3d/datasets/utils.py
Normal file
35
pytorch3d/datasets/utils.py
Normal file
@ -0,0 +1,35 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
|
||||
def collate_batched_meshes(batch: List[Dict]):
|
||||
"""
|
||||
Take a list of objects in the form of dictionaries and merge them
|
||||
into a single dictionary. This function can be used with a Dataset
|
||||
object to create a torch.utils.data.Dataloader which directly
|
||||
returns Meshes objects.
|
||||
TODO: Add support for textures.
|
||||
|
||||
Args:
|
||||
batch: List of dictionaries containing information about objects
|
||||
in the dataset.
|
||||
Returns:
|
||||
collated_dict: Dictionary of collated lists. If batch contains both
|
||||
verts and faces, a collated mesh batch is also returned.
|
||||
"""
|
||||
if batch is None or len(batch) == 0:
|
||||
return None
|
||||
collated_dict = {}
|
||||
for k in batch[0].keys():
|
||||
collated_dict[k] = [d[k] for d in batch]
|
||||
|
||||
collated_dict["mesh"] = None
|
||||
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
||||
collated_dict["mesh"] = Meshes(
|
||||
verts=collated_dict["verts"], faces=collated_dict["faces"]
|
||||
)
|
||||
|
||||
return collated_dict
|
@ -10,13 +10,14 @@ import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, load_rgb_image
|
||||
from PIL import Image
|
||||
from pytorch3d.datasets import ShapeNetCore
|
||||
from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes
|
||||
from pytorch3d.renderer import (
|
||||
OpenGLPerspectiveCameras,
|
||||
PointLights,
|
||||
RasterizationSettings,
|
||||
look_at_view_transform,
|
||||
)
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
# Set the SHAPENET_PATH to the local path to the dataset
|
||||
@ -110,6 +111,38 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
|
||||
]
|
||||
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
|
||||
|
||||
def test_collate_models(self):
|
||||
"""
|
||||
Test collate_batched_meshes returns items of the correct shapes and types.
|
||||
Check that when collate_batched_meshes is passed to Dataloader, batches of
|
||||
the correct shapes and types are returned.
|
||||
"""
|
||||
# Load ShapeNetCore without specifying any particular categories.
|
||||
shapenet_dataset = ShapeNetCore(SHAPENET_PATH)
|
||||
# Randomly retrieve several objects from the dataset.
|
||||
rand_idxs = torch.randint(len(shapenet_dataset), (6,))
|
||||
rand_objs = [shapenet_dataset[idx] for idx in rand_idxs]
|
||||
|
||||
# Collate the randomly selected objects
|
||||
collated_meshes = collate_batched_meshes(rand_objs)
|
||||
verts, faces = (collated_meshes["verts"], collated_meshes["faces"])
|
||||
self.assertEqual(len(verts), 6)
|
||||
self.assertEqual(len(faces), 6)
|
||||
|
||||
# Pass the custom collate_fn function to DataLoader and check elements
|
||||
# in batch have the correct shape.
|
||||
batch_size = 12
|
||||
shapenet_core_loader = DataLoader(
|
||||
shapenet_dataset, batch_size=batch_size, collate_fn=collate_batched_meshes
|
||||
)
|
||||
it = iter(shapenet_core_loader)
|
||||
object_batch = next(it)
|
||||
self.assertEqual(len(object_batch["synset_id"]), batch_size)
|
||||
self.assertEqual(len(object_batch["model_id"]), batch_size)
|
||||
self.assertEqual(len(object_batch["label"]), batch_size)
|
||||
self.assertEqual(object_batch["mesh"].verts_padded().shape[0], batch_size)
|
||||
self.assertEqual(object_batch["mesh"].faces_padded().shape[0], batch_size)
|
||||
|
||||
def test_catch_render_arg_errors(self):
|
||||
"""
|
||||
Test rendering ShapeNetCore with invalid model_ids, categories or indices,
|
||||
|
Loading…
x
Reference in New Issue
Block a user