collate_batched_meshes for datasets

Summary: Adding collate_batched_meshes for datasets.utils: takes in a list of dictionaries and merge them into one dictionary (while adding a merged mesh to the dictionary).

Reviewed By: nikhilaravi

Differential Revision: D22180404

fbshipit-source-id: f811f9a140f09638f355ad5739bffa6ee415819f
This commit is contained in:
Luya Gao
2020-07-14 14:52:21 -07:00
committed by Facebook GitHub Bot
parent 22f2963cf1
commit 22d8c3337a
3 changed files with 70 additions and 1 deletions

View File

@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .shapenet import ShapeNetCore
from .utils import collate_batched_meshes
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@@ -0,0 +1,35 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Dict, List
from pytorch3d.structures import Meshes
def collate_batched_meshes(batch: List[Dict]):
"""
Take a list of objects in the form of dictionaries and merge them
into a single dictionary. This function can be used with a Dataset
object to create a torch.utils.data.Dataloader which directly
returns Meshes objects.
TODO: Add support for textures.
Args:
batch: List of dictionaries containing information about objects
in the dataset.
Returns:
collated_dict: Dictionary of collated lists. If batch contains both
verts and faces, a collated mesh batch is also returned.
"""
if batch is None or len(batch) == 0:
return None
collated_dict = {}
for k in batch[0].keys():
collated_dict[k] = [d[k] for d in batch]
collated_dict["mesh"] = None
if {"verts", "faces"}.issubset(collated_dict.keys()):
collated_dict["mesh"] = Meshes(
verts=collated_dict["verts"], faces=collated_dict["faces"]
)
return collated_dict