Type safety fixes

Summary: Pyre expects Mapping for ** operator.

Reviewed By: bottler

Differential Revision: D35288632

fbshipit-source-id: 34d6f26ad912b3a5046f440922bb6ed2fd86f533
This commit is contained in:
Roman Shapovalov 2022-04-01 04:24:46 -07:00 committed by Facebook GitHub Bot
parent 24260130ce
commit a999fc22ee
2 changed files with 9 additions and 3 deletions

View File

@ -16,11 +16,13 @@ from dataclasses import dataclass, field, fields
from itertools import islice from itertools import islice
from pathlib import Path from pathlib import Path
from typing import ( from typing import (
Any,
ClassVar, ClassVar,
Dict, Dict,
Iterable, Iterable,
Iterator, Iterator,
List, List,
Mapping,
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
@ -42,7 +44,7 @@ from . import types
@dataclass @dataclass
class FrameData: class FrameData(Mapping[str, Any]):
""" """
A type of the elements returned by indexing the dataset object. A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches of thereof; It can represent both individual frames and batches of thereof;
@ -137,13 +139,16 @@ class FrameData:
return self.to(device=torch.device("cuda")) return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions # the following functions make sure **frame_data can be passed to functions
def keys(self): def __iter__(self):
for f in fields(self): for f in fields(self):
yield f.name yield f.name
def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, key) return getattr(self, key)
def __len__(self):
return len(fields(self))
@classmethod @classmethod
def collate(cls, batch): def collate(cls, batch):
""" """

View File

@ -10,6 +10,7 @@ from typing import NamedTuple, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from pytorch3d.structures import Pointclouds
from .rasterize_points import rasterize_points from .rasterize_points import rasterize_points
@ -75,7 +76,7 @@ class PointsRasterizer(nn.Module):
self.cameras = cameras self.cameras = cameras
self.raster_settings = raster_settings self.raster_settings = raster_settings
def transform(self, point_clouds, **kwargs) -> torch.Tensor: def transform(self, point_clouds, **kwargs) -> Pointclouds:
""" """
Args: Args:
point_clouds: a set of point clouds point_clouds: a set of point clouds