diff --git a/pytorch3d/implicitron/dataset/implicitron_dataset.py b/pytorch3d/implicitron/dataset/implicitron_dataset.py index 88142b4e..4ff2da7d 100644 --- a/pytorch3d/implicitron/dataset/implicitron_dataset.py +++ b/pytorch3d/implicitron/dataset/implicitron_dataset.py @@ -16,11 +16,13 @@ from dataclasses import dataclass, field, fields from itertools import islice from pathlib import Path from typing import ( + Any, ClassVar, Dict, Iterable, Iterator, List, + Mapping, Optional, Sequence, Tuple, @@ -42,7 +44,7 @@ from . import types @dataclass -class FrameData: +class FrameData(Mapping[str, Any]): """ A type of the elements returned by indexing the dataset object. It can represent both individual frames and batches of thereof; @@ -137,13 +139,16 @@ class FrameData: return self.to(device=torch.device("cuda")) # the following functions make sure **frame_data can be passed to functions - def keys(self): + def __iter__(self): for f in fields(self): yield f.name def __getitem__(self, key): return getattr(self, key) + def __len__(self): + return len(fields(self)) + @classmethod def collate(cls, batch): """ diff --git a/pytorch3d/renderer/points/rasterizer.py b/pytorch3d/renderer/points/rasterizer.py index 73233ac7..cfa65da2 100644 --- a/pytorch3d/renderer/points/rasterizer.py +++ b/pytorch3d/renderer/points/rasterizer.py @@ -10,6 +10,7 @@ from typing import NamedTuple, Optional, Tuple, Union import torch import torch.nn as nn +from pytorch3d.structures import Pointclouds from .rasterize_points import rasterize_points @@ -75,7 +76,7 @@ class PointsRasterizer(nn.Module): self.cameras = cameras self.raster_settings = raster_settings - def transform(self, point_clouds, **kwargs) -> torch.Tensor: + def transform(self, point_clouds, **kwargs) -> Pointclouds: """ Args: point_clouds: a set of point clouds