Join points as batch

Summary: Function to join a list of pointclouds as a batch similar to the corresponding function for Meshes.

Reviewed By: bottler

Differential Revision: D33145906

fbshipit-source-id: 160639ebb5065e4fae1a1aa43117172719f3871b
This commit is contained in:
Nikhila Ravi
2021-12-21 04:43:09 -08:00
committed by Facebook GitHub Bot
parent eb2bbf8433
commit 262c1bfcd4
2 changed files with 102 additions and 1 deletions

View File

@@ -1178,3 +1178,40 @@ class Pointclouds:
coord_inside = (points_packed >= box[:, 0]) * (points_packed <= box[:, 1])
return coord_inside.all(dim=-1)
def join_pointclouds_as_batch(pointclouds: Sequence[Pointclouds]):
"""
Merge a list of Pointclouds objects into a single batched Pointclouds
object. All pointclouds must be on the same device.
Args:
batch: List of Pointclouds objects each with batch dim [b1, b2, ..., bN]
Returns:
pointcloud: Poinclouds object with all input pointclouds collated into
a single object with batch dim = sum(b1, b2, ..., bN)
"""
if isinstance(pointclouds, Pointclouds) or not isinstance(pointclouds, Sequence):
raise ValueError("Wrong first argument to join_points_as_batch.")
device = pointclouds[0].device
if not all(p.device == device for p in pointclouds):
raise ValueError("Pointclouds must all be on the same device")
kwargs = {}
for field in ("points", "normals", "features"):
field_list = [getattr(p, field + "_list")() for p in pointclouds]
if None in field_list:
if field == "points":
raise ValueError("Pointclouds cannot have their points set to None!")
if not all(f is None for f in field_list):
raise ValueError(
f"Pointclouds in the batch have some fields '{field}'"
+ " defined and some set to None."
)
field_list = None
else:
field_list = [p for points in field_list for p in points]
kwargs[field] = field_list
return Pointclouds(**kwargs)