Type safety fixes

Summary: Pyre expects Mapping for ** operator.

Reviewed By: bottler

Differential Revision: D35288632

fbshipit-source-id: 34d6f26ad912b3a5046f440922bb6ed2fd86f533
This commit is contained in:
Roman Shapovalov 2022-04-01 04:24:46 -07:00 committed by Facebook GitHub Bot
parent 24260130ce
commit a999fc22ee
2 changed files with 9 additions and 3 deletions

View File

@ -16,11 +16,13 @@ from dataclasses import dataclass, field, fields
from itertools import islice
from pathlib import Path
from typing import (
Any,
ClassVar,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
@ -42,7 +44,7 @@ from . import types
@dataclass
class FrameData:
class FrameData(Mapping[str, Any]):
"""
A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches of thereof;
@ -137,13 +139,16 @@ class FrameData:
return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions
def keys(self):
def __iter__(self):
for f in fields(self):
yield f.name
def __getitem__(self, key):
return getattr(self, key)
def __len__(self):
return len(fields(self))
@classmethod
def collate(cls, batch):
"""

View File

@ -10,6 +10,7 @@ from typing import NamedTuple, Optional, Tuple, Union
import torch
import torch.nn as nn
from pytorch3d.structures import Pointclouds
from .rasterize_points import rasterize_points
@ -75,7 +76,7 @@ class PointsRasterizer(nn.Module):
self.cameras = cameras
self.raster_settings = raster_settings
def transform(self, point_clouds, **kwargs) -> torch.Tensor:
def transform(self, point_clouds, **kwargs) -> Pointclouds:
"""
Args:
point_clouds: a set of point clouds